]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/wrap_return_type_in_result.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / wrap_return_type_in_result.rs
1 use std::iter;
2
3 use ide_db::helpers::{for_each_tail_expr, node_ext::walk_expr, FamousDefs};
4 use syntax::{
5     ast::{self, make, Expr},
6     match_ast, AstNode,
7 };
8
9 use crate::{AssistContext, AssistId, AssistKind, Assists};
10
11 // Assist: wrap_return_type_in_result
12 //
13 // Wrap the function's return type into Result.
14 //
15 // ```
16 // # //- minicore: result
17 // fn foo() -> i32$0 { 42i32 }
18 // ```
19 // ->
20 // ```
21 // fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
22 // ```
23 pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
24     let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
25     let parent = ret_type.syntax().parent()?;
26     let body = match_ast! {
27         match parent {
28             ast::Fn(func) => func.body()?,
29             ast::ClosureExpr(closure) => match closure.body()? {
30                 Expr::BlockExpr(block) => block,
31                 // closures require a block when a return type is specified
32                 _ => return None,
33             },
34             _ => return None,
35         }
36     };
37
38     let type_ref = &ret_type.ty()?;
39     let ty = ctx.sema.resolve_type(type_ref).and_then(|ty| ty.as_adt());
40     let result_enum =
41         FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax()).krate()).core_result_Result()?;
42
43     if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
44         cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
45         return None;
46     }
47
48     acc.add(
49         AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite),
50         "Wrap return type in Result",
51         type_ref.syntax().text_range(),
52         |builder| {
53             let body = ast::Expr::BlockExpr(body);
54
55             let mut exprs_to_wrap = Vec::new();
56             let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e);
57             walk_expr(&body, &mut |expr| {
58                 if let Expr::ReturnExpr(ret_expr) = expr {
59                     if let Some(ret_expr_arg) = &ret_expr.expr() {
60                         for_each_tail_expr(ret_expr_arg, tail_cb);
61                     }
62                 }
63             });
64             for_each_tail_expr(&body, tail_cb);
65
66             for ret_expr_arg in exprs_to_wrap {
67                 let ok_wrapped = make::expr_call(
68                     make::expr_path(make::ext::ident_path("Ok")),
69                     make::arg_list(iter::once(ret_expr_arg.clone())),
70                 );
71                 builder.replace_ast(ret_expr_arg, ok_wrapped);
72             }
73
74             match ctx.config.snippet_cap {
75                 Some(cap) => {
76                     let snippet = format!("Result<{}, ${{0:_}}>", type_ref);
77                     builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet)
78                 }
79                 None => builder
80                     .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)),
81             }
82         },
83     )
84 }
85
86 fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
87     match e {
88         Expr::BreakExpr(break_expr) => {
89             if let Some(break_expr_arg) = break_expr.expr() {
90                 for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
91             }
92         }
93         Expr::ReturnExpr(ret_expr) => {
94             if let Some(ret_expr_arg) = &ret_expr.expr() {
95                 for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e));
96             }
97         }
98         e => acc.push(e.clone()),
99     }
100 }
101
102 #[cfg(test)]
103 mod tests {
104     use crate::tests::{check_assist, check_assist_not_applicable};
105
106     use super::*;
107
108     #[test]
109     fn wrap_return_type_in_result_simple() {
110         check_assist(
111             wrap_return_type_in_result,
112             r#"
113 //- minicore: result
114 fn foo() -> i3$02 {
115     let test = "test";
116     return 42i32;
117 }
118 "#,
119             r#"
120 fn foo() -> Result<i32, ${0:_}> {
121     let test = "test";
122     return Ok(42i32);
123 }
124 "#,
125         );
126     }
127
128     #[test]
129     fn wrap_return_type_break_split_tail() {
130         check_assist(
131             wrap_return_type_in_result,
132             r#"
133 //- minicore: result
134 fn foo() -> i3$02 {
135     loop {
136         break if true {
137             1
138         } else {
139             0
140         };
141     }
142 }
143 "#,
144             r#"
145 fn foo() -> Result<i32, ${0:_}> {
146     loop {
147         break if true {
148             Ok(1)
149         } else {
150             Ok(0)
151         };
152     }
153 }
154 "#,
155         );
156     }
157
158     #[test]
159     fn wrap_return_type_in_result_simple_closure() {
160         check_assist(
161             wrap_return_type_in_result,
162             r#"
163 //- minicore: result
164 fn foo() {
165     || -> i32$0 {
166         let test = "test";
167         return 42i32;
168     };
169 }
170 "#,
171             r#"
172 fn foo() {
173     || -> Result<i32, ${0:_}> {
174         let test = "test";
175         return Ok(42i32);
176     };
177 }
178 "#,
179         );
180     }
181
182     #[test]
183     fn wrap_return_type_in_result_simple_return_type_bad_cursor() {
184         check_assist_not_applicable(
185             wrap_return_type_in_result,
186             r#"
187 //- minicore: result
188 fn foo() -> i32 {
189     let test = "test";$0
190     return 42i32;
191 }
192 "#,
193         );
194     }
195
196     #[test]
197     fn wrap_return_type_in_result_simple_return_type_bad_cursor_closure() {
198         check_assist_not_applicable(
199             wrap_return_type_in_result,
200             r#"
201 //- minicore: result
202 fn foo() {
203     || -> i32 {
204         let test = "test";$0
205         return 42i32;
206     };
207 }
208 "#,
209         );
210     }
211
212     #[test]
213     fn wrap_return_type_in_result_closure_non_block() {
214         check_assist_not_applicable(
215             wrap_return_type_in_result,
216             r#"
217 //- minicore: result
218 fn foo() { || -> i$032 3; }
219 "#,
220         );
221     }
222
223     #[test]
224     fn wrap_return_type_in_result_simple_return_type_already_result_std() {
225         check_assist_not_applicable(
226             wrap_return_type_in_result,
227             r#"
228 //- minicore: result
229 fn foo() -> core::result::Result<i32$0, String> {
230     let test = "test";
231     return 42i32;
232 }
233 "#,
234         );
235     }
236
237     #[test]
238     fn wrap_return_type_in_result_simple_return_type_already_result() {
239         cov_mark::check!(wrap_return_type_in_result_simple_return_type_already_result);
240         check_assist_not_applicable(
241             wrap_return_type_in_result,
242             r#"
243 //- minicore: result
244 fn foo() -> Result<i32$0, String> {
245     let test = "test";
246     return 42i32;
247 }
248 "#,
249         );
250     }
251
252     #[test]
253     fn wrap_return_type_in_result_simple_return_type_already_result_closure() {
254         check_assist_not_applicable(
255             wrap_return_type_in_result,
256             r#"
257 //- minicore: result
258 fn foo() {
259     || -> Result<i32$0, String> {
260         let test = "test";
261         return 42i32;
262     };
263 }
264 "#,
265         );
266     }
267
268     #[test]
269     fn wrap_return_type_in_result_simple_with_cursor() {
270         check_assist(
271             wrap_return_type_in_result,
272             r#"
273 //- minicore: result
274 fn foo() -> $0i32 {
275     let test = "test";
276     return 42i32;
277 }
278 "#,
279             r#"
280 fn foo() -> Result<i32, ${0:_}> {
281     let test = "test";
282     return Ok(42i32);
283 }
284 "#,
285         );
286     }
287
288     #[test]
289     fn wrap_return_type_in_result_simple_with_tail() {
290         check_assist(
291             wrap_return_type_in_result,
292             r#"
293 //- minicore: result
294 fn foo() ->$0 i32 {
295     let test = "test";
296     42i32
297 }
298 "#,
299             r#"
300 fn foo() -> Result<i32, ${0:_}> {
301     let test = "test";
302     Ok(42i32)
303 }
304 "#,
305         );
306     }
307
308     #[test]
309     fn wrap_return_type_in_result_simple_with_tail_closure() {
310         check_assist(
311             wrap_return_type_in_result,
312             r#"
313 //- minicore: result
314 fn foo() {
315     || ->$0 i32 {
316         let test = "test";
317         42i32
318     };
319 }
320 "#,
321             r#"
322 fn foo() {
323     || -> Result<i32, ${0:_}> {
324         let test = "test";
325         Ok(42i32)
326     };
327 }
328 "#,
329         );
330     }
331
332     #[test]
333     fn wrap_return_type_in_result_simple_with_tail_only() {
334         check_assist(
335             wrap_return_type_in_result,
336             r#"
337 //- minicore: result
338 fn foo() -> i32$0 { 42i32 }
339 "#,
340             r#"
341 fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
342 "#,
343         );
344     }
345
346     #[test]
347     fn wrap_return_type_in_result_simple_with_tail_block_like() {
348         check_assist(
349             wrap_return_type_in_result,
350             r#"
351 //- minicore: result
352 fn foo() -> i32$0 {
353     if true {
354         42i32
355     } else {
356         24i32
357     }
358 }
359 "#,
360             r#"
361 fn foo() -> Result<i32, ${0:_}> {
362     if true {
363         Ok(42i32)
364     } else {
365         Ok(24i32)
366     }
367 }
368 "#,
369         );
370     }
371
372     #[test]
373     fn wrap_return_type_in_result_simple_without_block_closure() {
374         check_assist(
375             wrap_return_type_in_result,
376             r#"
377 //- minicore: result
378 fn foo() {
379     || -> i32$0 {
380         if true {
381             42i32
382         } else {
383             24i32
384         }
385     };
386 }
387 "#,
388             r#"
389 fn foo() {
390     || -> Result<i32, ${0:_}> {
391         if true {
392             Ok(42i32)
393         } else {
394             Ok(24i32)
395         }
396     };
397 }
398 "#,
399         );
400     }
401
402     #[test]
403     fn wrap_return_type_in_result_simple_with_nested_if() {
404         check_assist(
405             wrap_return_type_in_result,
406             r#"
407 //- minicore: result
408 fn foo() -> i32$0 {
409     if true {
410         if false {
411             1
412         } else {
413             2
414         }
415     } else {
416         24i32
417     }
418 }
419 "#,
420             r#"
421 fn foo() -> Result<i32, ${0:_}> {
422     if true {
423         if false {
424             Ok(1)
425         } else {
426             Ok(2)
427         }
428     } else {
429         Ok(24i32)
430     }
431 }
432 "#,
433         );
434     }
435
436     #[test]
437     fn wrap_return_type_in_result_simple_with_await() {
438         check_assist(
439             wrap_return_type_in_result,
440             r#"
441 //- minicore: result
442 async fn foo() -> i$032 {
443     if true {
444         if false {
445             1.await
446         } else {
447             2.await
448         }
449     } else {
450         24i32.await
451     }
452 }
453 "#,
454             r#"
455 async fn foo() -> Result<i32, ${0:_}> {
456     if true {
457         if false {
458             Ok(1.await)
459         } else {
460             Ok(2.await)
461         }
462     } else {
463         Ok(24i32.await)
464     }
465 }
466 "#,
467         );
468     }
469
470     #[test]
471     fn wrap_return_type_in_result_simple_with_array() {
472         check_assist(
473             wrap_return_type_in_result,
474             r#"
475 //- minicore: result
476 fn foo() -> [i32;$0 3] { [1, 2, 3] }
477 "#,
478             r#"
479 fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }
480 "#,
481         );
482     }
483
484     #[test]
485     fn wrap_return_type_in_result_simple_with_cast() {
486         check_assist(
487             wrap_return_type_in_result,
488             r#"
489 //- minicore: result
490 fn foo() -$0> i32 {
491     if true {
492         if false {
493             1 as i32
494         } else {
495             2 as i32
496         }
497     } else {
498         24 as i32
499     }
500 }
501 "#,
502             r#"
503 fn foo() -> Result<i32, ${0:_}> {
504     if true {
505         if false {
506             Ok(1 as i32)
507         } else {
508             Ok(2 as i32)
509         }
510     } else {
511         Ok(24 as i32)
512     }
513 }
514 "#,
515         );
516     }
517
518     #[test]
519     fn wrap_return_type_in_result_simple_with_tail_block_like_match() {
520         check_assist(
521             wrap_return_type_in_result,
522             r#"
523 //- minicore: result
524 fn foo() -> i32$0 {
525     let my_var = 5;
526     match my_var {
527         5 => 42i32,
528         _ => 24i32,
529     }
530 }
531 "#,
532             r#"
533 fn foo() -> Result<i32, ${0:_}> {
534     let my_var = 5;
535     match my_var {
536         5 => Ok(42i32),
537         _ => Ok(24i32),
538     }
539 }
540 "#,
541         );
542     }
543
544     #[test]
545     fn wrap_return_type_in_result_simple_with_loop_with_tail() {
546         check_assist(
547             wrap_return_type_in_result,
548             r#"
549 //- minicore: result
550 fn foo() -> i32$0 {
551     let my_var = 5;
552     loop {
553         println!("test");
554         5
555     }
556     my_var
557 }
558 "#,
559             r#"
560 fn foo() -> Result<i32, ${0:_}> {
561     let my_var = 5;
562     loop {
563         println!("test");
564         5
565     }
566     Ok(my_var)
567 }
568 "#,
569         );
570     }
571
572     #[test]
573     fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() {
574         check_assist(
575             wrap_return_type_in_result,
576             r#"
577 //- minicore: result
578 fn foo() -> i32$0 {
579     let my_var = let x = loop {
580         break 1;
581     };
582     my_var
583 }
584 "#,
585             r#"
586 fn foo() -> Result<i32, ${0:_}> {
587     let my_var = let x = loop {
588         break 1;
589     };
590     Ok(my_var)
591 }
592 "#,
593         );
594     }
595
596     #[test]
597     fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() {
598         check_assist(
599             wrap_return_type_in_result,
600             r#"
601 //- minicore: result
602 fn foo() -> i32$0 {
603     let my_var = 5;
604     let res = match my_var {
605         5 => 42i32,
606         _ => return 24i32,
607     };
608     res
609 }
610 "#,
611             r#"
612 fn foo() -> Result<i32, ${0:_}> {
613     let my_var = 5;
614     let res = match my_var {
615         5 => 42i32,
616         _ => return Ok(24i32),
617     };
618     Ok(res)
619 }
620 "#,
621         );
622
623         check_assist(
624             wrap_return_type_in_result,
625             r#"
626 //- minicore: result
627 fn foo() -> i32$0 {
628     let my_var = 5;
629     let res = if my_var == 5 {
630         42i32
631     } else {
632         return 24i32;
633     };
634     res
635 }
636 "#,
637             r#"
638 fn foo() -> Result<i32, ${0:_}> {
639     let my_var = 5;
640     let res = if my_var == 5 {
641         42i32
642     } else {
643         return Ok(24i32);
644     };
645     Ok(res)
646 }
647 "#,
648         );
649     }
650
651     #[test]
652     fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() {
653         check_assist(
654             wrap_return_type_in_result,
655             r#"
656 //- minicore: result
657 fn foo() -> i32$0 {
658     let my_var = 5;
659     match my_var {
660         5 => {
661             if true {
662                 42i32
663             } else {
664                 25i32
665             }
666         },
667         _ => {
668             let test = "test";
669             if test == "test" {
670                 return bar();
671             }
672             53i32
673         },
674     }
675 }
676 "#,
677             r#"
678 fn foo() -> Result<i32, ${0:_}> {
679     let my_var = 5;
680     match my_var {
681         5 => {
682             if true {
683                 Ok(42i32)
684             } else {
685                 Ok(25i32)
686             }
687         },
688         _ => {
689             let test = "test";
690             if test == "test" {
691                 return Ok(bar());
692             }
693             Ok(53i32)
694         },
695     }
696 }
697 "#,
698         );
699     }
700
701     #[test]
702     fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() {
703         check_assist(
704             wrap_return_type_in_result,
705             r#"
706 //- minicore: result
707 fn foo() -> i$032 {
708     let test = "test";
709     if test == "test" {
710         return 24i32;
711     }
712     53i32
713 }
714 "#,
715             r#"
716 fn foo() -> Result<i32, ${0:_}> {
717     let test = "test";
718     if test == "test" {
719         return Ok(24i32);
720     }
721     Ok(53i32)
722 }
723 "#,
724         );
725     }
726
727     #[test]
728     fn wrap_return_type_in_result_simple_with_closure() {
729         check_assist(
730             wrap_return_type_in_result,
731             r#"
732 //- minicore: result
733 fn foo(the_field: u32) ->$0 u32 {
734     let true_closure = || { return true; };
735     if the_field < 5 {
736         let mut i = 0;
737         if true_closure() {
738             return 99;
739         } else {
740             return 0;
741         }
742     }
743     the_field
744 }
745 "#,
746             r#"
747 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
748     let true_closure = || { return true; };
749     if the_field < 5 {
750         let mut i = 0;
751         if true_closure() {
752             return Ok(99);
753         } else {
754             return Ok(0);
755         }
756     }
757     Ok(the_field)
758 }
759 "#,
760         );
761
762         check_assist(
763             wrap_return_type_in_result,
764             r#"
765 //- minicore: result
766 fn foo(the_field: u32) -> u32$0 {
767     let true_closure = || {
768         return true;
769     };
770     if the_field < 5 {
771         let mut i = 0;
772
773
774         if true_closure() {
775             return 99;
776         } else {
777             return 0;
778         }
779     }
780     let t = None;
781
782     t.unwrap_or_else(|| the_field)
783 }
784 "#,
785             r#"
786 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
787     let true_closure = || {
788         return true;
789     };
790     if the_field < 5 {
791         let mut i = 0;
792
793
794         if true_closure() {
795             return Ok(99);
796         } else {
797             return Ok(0);
798         }
799     }
800     let t = None;
801
802     Ok(t.unwrap_or_else(|| the_field))
803 }
804 "#,
805         );
806     }
807
808     #[test]
809     fn wrap_return_type_in_result_simple_with_weird_forms() {
810         check_assist(
811             wrap_return_type_in_result,
812             r#"
813 //- minicore: result
814 fn foo() -> i32$0 {
815     let test = "test";
816     if test == "test" {
817         return 24i32;
818     }
819     let mut i = 0;
820     loop {
821         if i == 1 {
822             break 55;
823         }
824         i += 1;
825     }
826 }
827 "#,
828             r#"
829 fn foo() -> Result<i32, ${0:_}> {
830     let test = "test";
831     if test == "test" {
832         return Ok(24i32);
833     }
834     let mut i = 0;
835     loop {
836         if i == 1 {
837             break Ok(55);
838         }
839         i += 1;
840     }
841 }
842 "#,
843         );
844
845         check_assist(
846             wrap_return_type_in_result,
847             r#"
848 //- minicore: result
849 fn foo(the_field: u32) -> u32$0 {
850     if the_field < 5 {
851         let mut i = 0;
852         loop {
853             if i > 5 {
854                 return 55u32;
855             }
856             i += 3;
857         }
858         match i {
859             5 => return 99,
860             _ => return 0,
861         };
862     }
863     the_field
864 }
865 "#,
866             r#"
867 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
868     if the_field < 5 {
869         let mut i = 0;
870         loop {
871             if i > 5 {
872                 return Ok(55u32);
873             }
874             i += 3;
875         }
876         match i {
877             5 => return Ok(99),
878             _ => return Ok(0),
879         };
880     }
881     Ok(the_field)
882 }
883 "#,
884         );
885
886         check_assist(
887             wrap_return_type_in_result,
888             r#"
889 //- minicore: result
890 fn foo(the_field: u32) -> u3$02 {
891     if the_field < 5 {
892         let mut i = 0;
893         match i {
894             5 => return 99,
895             _ => return 0,
896         }
897     }
898     the_field
899 }
900 "#,
901             r#"
902 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
903     if the_field < 5 {
904         let mut i = 0;
905         match i {
906             5 => return Ok(99),
907             _ => return Ok(0),
908         }
909     }
910     Ok(the_field)
911 }
912 "#,
913         );
914
915         check_assist(
916             wrap_return_type_in_result,
917             r#"
918 //- minicore: result
919 fn foo(the_field: u32) -> u32$0 {
920     if the_field < 5 {
921         let mut i = 0;
922         if i == 5 {
923             return 99
924         } else {
925             return 0
926         }
927     }
928     the_field
929 }
930 "#,
931             r#"
932 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
933     if the_field < 5 {
934         let mut i = 0;
935         if i == 5 {
936             return Ok(99)
937         } else {
938             return Ok(0)
939         }
940     }
941     Ok(the_field)
942 }
943 "#,
944         );
945
946         check_assist(
947             wrap_return_type_in_result,
948             r#"
949 //- minicore: result
950 fn foo(the_field: u32) -> $0u32 {
951     if the_field < 5 {
952         let mut i = 0;
953         if i == 5 {
954             return 99;
955         } else {
956             return 0;
957         }
958     }
959     the_field
960 }
961 "#,
962             r#"
963 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
964     if the_field < 5 {
965         let mut i = 0;
966         if i == 5 {
967             return Ok(99);
968         } else {
969             return Ok(0);
970         }
971     }
972     Ok(the_field)
973 }
974 "#,
975         );
976     }
977 }