]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/unwrap_result_return_type.rs
Merge #10417
[rust.git] / crates / ide_assists / src / handlers / unwrap_result_return_type.rs
1 use ide_db::helpers::{for_each_tail_expr, node_ext::walk_expr, FamousDefs};
2 use syntax::{
3     ast::{self, Expr},
4     match_ast, AstNode,
5 };
6
7 use crate::{AssistContext, AssistId, AssistKind, Assists};
8
9 // Assist: unwrap_result_return_type
10 //
11 // Unwrap the function's return type.
12 //
13 // ```
14 // # //- minicore: result
15 // fn foo() -> Result<i32>$0 { Ok(42i32) }
16 // ```
17 // ->
18 // ```
19 // fn foo() -> i32 { 42i32 }
20 // ```
21 pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
22     let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
23     let parent = ret_type.syntax().parent()?;
24     let body = match_ast! {
25         match parent {
26             ast::Fn(func) => func.body()?,
27             ast::ClosureExpr(closure) => match closure.body()? {
28                 Expr::BlockExpr(block) => block,
29                 // closures require a block when a return type is specified
30                 _ => return None,
31             },
32             _ => return None,
33         }
34     };
35
36     let type_ref = &ret_type.ty()?;
37     let ty = ctx.sema.resolve_type(type_ref).and_then(|ty| ty.as_adt());
38     let result_enum =
39         FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax()).krate()).core_result_Result()?;
40
41     if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
42         return None;
43     }
44
45     acc.add(
46         AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite),
47         "Unwrap Result return type",
48         type_ref.syntax().text_range(),
49         |builder| {
50             let body = ast::Expr::BlockExpr(body);
51
52             let mut exprs_to_unwrap = Vec::new();
53             let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_unwrap, e);
54             walk_expr(&body, &mut |expr| {
55                 if let Expr::ReturnExpr(ret_expr) = expr {
56                     if let Some(ret_expr_arg) = &ret_expr.expr() {
57                         for_each_tail_expr(ret_expr_arg, tail_cb);
58                     }
59                 }
60             });
61             for_each_tail_expr(&body, tail_cb);
62
63             for ret_expr_arg in exprs_to_unwrap {
64                 let new_ret_expr = ret_expr_arg.to_string();
65                 let new_ret_expr =
66                     new_ret_expr.trim_start_matches("Ok(").trim_start_matches("Err(");
67                 builder.replace(
68                     ret_expr_arg.syntax().text_range(),
69                     new_ret_expr.strip_suffix(')').unwrap_or(new_ret_expr),
70                 )
71             }
72
73             if let Some((_, inner_type)) = type_ref.to_string().split_once('<') {
74                 let inner_type = match inner_type.split_once(',') {
75                     Some((success_inner_type, _)) => success_inner_type,
76                     None => inner_type,
77                 };
78                 builder.replace(
79                     type_ref.syntax().text_range(),
80                     inner_type.strip_suffix('>').unwrap_or(inner_type),
81                 )
82             }
83         },
84     )
85 }
86
87 fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
88     match e {
89         Expr::BreakExpr(break_expr) => {
90             if let Some(break_expr_arg) = break_expr.expr() {
91                 for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
92             }
93         }
94         Expr::ReturnExpr(ret_expr) => {
95             if let Some(ret_expr_arg) = &ret_expr.expr() {
96                 for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e));
97             }
98         }
99         e => acc.push(e.clone()),
100     }
101 }
102
103 #[cfg(test)]
104 mod tests {
105     use crate::tests::{check_assist, check_assist_not_applicable};
106
107     use super::*;
108
109     #[test]
110     fn unwrap_result_return_type_simple() {
111         check_assist(
112             unwrap_result_return_type,
113             r#"
114 //- minicore: result
115 fn foo() -> Result<i3$02> {
116     let test = "test";
117     return Ok(42i32);
118 }
119 "#,
120             r#"
121 fn foo() -> i32 {
122     let test = "test";
123     return 42i32;
124 }
125 "#,
126         );
127     }
128
129     #[test]
130     fn unwrap_return_type_break_split_tail() {
131         check_assist(
132             unwrap_result_return_type,
133             r#"
134 //- minicore: result
135 fn foo() -> Result<i3$02, String> {
136     loop {
137         break if true {
138             Ok(1)
139         } else {
140             Ok(0)
141         };
142     }
143 }
144 "#,
145             r#"
146 fn foo() -> i32 {
147     loop {
148         break if true {
149             1
150         } else {
151             0
152         };
153     }
154 }
155 "#,
156         );
157     }
158
159     #[test]
160     fn unwrap_result_return_type_simple_closure() {
161         check_assist(
162             unwrap_result_return_type,
163             r#"
164 //- minicore: result
165 fn foo() {
166     || -> Result<i32$0> {
167         let test = "test";
168         return Ok(42i32);
169     };
170 }
171 "#,
172             r#"
173 fn foo() {
174     || -> i32 {
175         let test = "test";
176         return 42i32;
177     };
178 }
179 "#,
180         );
181     }
182
183     #[test]
184     fn unwrap_result_return_type_simple_return_type_bad_cursor() {
185         check_assist_not_applicable(
186             unwrap_result_return_type,
187             r#"
188 //- minicore: result
189 fn foo() -> i32 {
190     let test = "test";$0
191     return 42i32;
192 }
193 "#,
194         );
195     }
196
197     #[test]
198     fn unwrap_result_return_type_simple_return_type_bad_cursor_closure() {
199         check_assist_not_applicable(
200             unwrap_result_return_type,
201             r#"
202 //- minicore: result
203 fn foo() {
204     || -> i32 {
205         let test = "test";$0
206         return 42i32;
207     };
208 }
209 "#,
210         );
211     }
212
213     #[test]
214     fn unwrap_result_return_type_closure_non_block() {
215         check_assist_not_applicable(
216             unwrap_result_return_type,
217             r#"
218 //- minicore: result
219 fn foo() { || -> i$032 3; }
220 "#,
221         );
222     }
223
224     #[test]
225     fn unwrap_result_return_type_simple_return_type_already_not_result_std() {
226         check_assist_not_applicable(
227             unwrap_result_return_type,
228             r#"
229 //- minicore: result
230 fn foo() -> i32$0 {
231     let test = "test";
232     return 42i32;
233 }
234 "#,
235         );
236     }
237
238     #[test]
239     fn unwrap_result_return_type_simple_return_type_already_not_result_closure() {
240         check_assist_not_applicable(
241             unwrap_result_return_type,
242             r#"
243 //- minicore: result
244 fn foo() {
245     || -> i32$0 {
246         let test = "test";
247         return 42i32;
248     };
249 }
250 "#,
251         );
252     }
253
254     #[test]
255     fn unwrap_result_return_type_simple_with_tail() {
256         check_assist(
257             unwrap_result_return_type,
258             r#"
259 //- minicore: result
260 fn foo() ->$0 Result<i32> {
261     let test = "test";
262     Ok(42i32)
263 }
264 "#,
265             r#"
266 fn foo() -> i32 {
267     let test = "test";
268     42i32
269 }
270 "#,
271         );
272     }
273
274     #[test]
275     fn unwrap_result_return_type_simple_with_tail_closure() {
276         check_assist(
277             unwrap_result_return_type,
278             r#"
279 //- minicore: result
280 fn foo() {
281     || ->$0 Result<i32, String> {
282         let test = "test";
283         Ok(42i32)
284     };
285 }
286 "#,
287             r#"
288 fn foo() {
289     || -> i32 {
290         let test = "test";
291         42i32
292     };
293 }
294 "#,
295         );
296     }
297
298     #[test]
299     fn unwrap_result_return_type_simple_with_tail_only() {
300         check_assist(
301             unwrap_result_return_type,
302             r#"
303 //- minicore: result
304 fn foo() -> Result<i32$0> { Ok(42i32) }
305 "#,
306             r#"
307 fn foo() -> i32 { 42i32 }
308 "#,
309         );
310     }
311
312     #[test]
313     fn unwrap_result_return_type_simple_with_tail_block_like() {
314         check_assist(
315             unwrap_result_return_type,
316             r#"
317 //- minicore: result
318 fn foo() -> Result<i32>$0 {
319     if true {
320         Ok(42i32)
321     } else {
322         Ok(24i32)
323     }
324 }
325 "#,
326             r#"
327 fn foo() -> i32 {
328     if true {
329         42i32
330     } else {
331         24i32
332     }
333 }
334 "#,
335         );
336     }
337
338     #[test]
339     fn unwrap_result_return_type_simple_without_block_closure() {
340         check_assist(
341             unwrap_result_return_type,
342             r#"
343 //- minicore: result
344 fn foo() {
345     || -> Result<i32, String>$0 {
346         if true {
347             Ok(42i32)
348         } else {
349             Ok(24i32)
350         }
351     };
352 }
353 "#,
354             r#"
355 fn foo() {
356     || -> i32 {
357         if true {
358             42i32
359         } else {
360             24i32
361         }
362     };
363 }
364 "#,
365         );
366     }
367
368     #[test]
369     fn unwrap_result_return_type_simple_with_nested_if() {
370         check_assist(
371             unwrap_result_return_type,
372             r#"
373 //- minicore: result
374 fn foo() -> Result<i32>$0 {
375     if true {
376         if false {
377             Ok(1)
378         } else {
379             Ok(2)
380         }
381     } else {
382         Ok(24i32)
383     }
384 }
385 "#,
386             r#"
387 fn foo() -> i32 {
388     if true {
389         if false {
390             1
391         } else {
392             2
393         }
394     } else {
395         24i32
396     }
397 }
398 "#,
399         );
400     }
401
402     #[test]
403     fn unwrap_result_return_type_simple_with_await() {
404         check_assist(
405             unwrap_result_return_type,
406             r#"
407 //- minicore: result
408 async fn foo() -> Result<i$032> {
409     if true {
410         if false {
411             Ok(1.await)
412         } else {
413             Ok(2.await)
414         }
415     } else {
416         Ok(24i32.await)
417     }
418 }
419 "#,
420             r#"
421 async fn foo() -> i32 {
422     if true {
423         if false {
424             1.await
425         } else {
426             2.await
427         }
428     } else {
429         24i32.await
430     }
431 }
432 "#,
433         );
434     }
435
436     #[test]
437     fn unwrap_result_return_type_simple_with_array() {
438         check_assist(
439             unwrap_result_return_type,
440             r#"
441 //- minicore: result
442 fn foo() -> Result<[i32; 3]$0> { Ok([1, 2, 3]) }
443 "#,
444             r#"
445 fn foo() -> [i32; 3] { [1, 2, 3] }
446 "#,
447         );
448     }
449
450     #[test]
451     fn unwrap_result_return_type_simple_with_cast() {
452         check_assist(
453             unwrap_result_return_type,
454             r#"
455 //- minicore: result
456 fn foo() -$0> Result<i32> {
457     if true {
458         if false {
459             Ok(1 as i32)
460         } else {
461             Ok(2 as i32)
462         }
463     } else {
464         Ok(24 as i32)
465     }
466 }
467 "#,
468             r#"
469 fn foo() -> i32 {
470     if true {
471         if false {
472             1 as i32
473         } else {
474             2 as i32
475         }
476     } else {
477         24 as i32
478     }
479 }
480 "#,
481         );
482     }
483
484     #[test]
485     fn unwrap_result_return_type_simple_with_tail_block_like_match() {
486         check_assist(
487             unwrap_result_return_type,
488             r#"
489 //- minicore: result
490 fn foo() -> Result<i32$0> {
491     let my_var = 5;
492     match my_var {
493         5 => Ok(42i32),
494         _ => Ok(24i32),
495     }
496 }
497 "#,
498             r#"
499 fn foo() -> i32 {
500     let my_var = 5;
501     match my_var {
502         5 => 42i32,
503         _ => 24i32,
504     }
505 }
506 "#,
507         );
508     }
509
510     #[test]
511     fn unwrap_result_return_type_simple_with_loop_with_tail() {
512         check_assist(
513             unwrap_result_return_type,
514             r#"
515 //- minicore: result
516 fn foo() -> Result<i32$0> {
517     let my_var = 5;
518     loop {
519         println!("test");
520         5
521     }
522     Ok(my_var)
523 }
524 "#,
525             r#"
526 fn foo() -> i32 {
527     let my_var = 5;
528     loop {
529         println!("test");
530         5
531     }
532     my_var
533 }
534 "#,
535         );
536     }
537
538     #[test]
539     fn unwrap_result_return_type_simple_with_loop_in_let_stmt() {
540         check_assist(
541             unwrap_result_return_type,
542             r#"
543 //- minicore: result
544 fn foo() -> Result<i32$0> {
545     let my_var = let x = loop {
546         break 1;
547     };
548     Ok(my_var)
549 }
550 "#,
551             r#"
552 fn foo() -> i32 {
553     let my_var = let x = loop {
554         break 1;
555     };
556     my_var
557 }
558 "#,
559         );
560     }
561
562     #[test]
563     fn unwrap_result_return_type_simple_with_tail_block_like_match_return_expr() {
564         check_assist(
565             unwrap_result_return_type,
566             r#"
567 //- minicore: result
568 fn foo() -> Result<i32>$0 {
569     let my_var = 5;
570     let res = match my_var {
571         5 => 42i32,
572         _ => return Ok(24i32),
573     };
574     Ok(res)
575 }
576 "#,
577             r#"
578 fn foo() -> i32 {
579     let my_var = 5;
580     let res = match my_var {
581         5 => 42i32,
582         _ => return 24i32,
583     };
584     res
585 }
586 "#,
587         );
588
589         check_assist(
590             unwrap_result_return_type,
591             r#"
592 //- minicore: result
593 fn foo() -> Result<i32$0> {
594     let my_var = 5;
595     let res = if my_var == 5 {
596         42i32
597     } else {
598         return Ok(24i32);
599     };
600     Ok(res)
601 }
602 "#,
603             r#"
604 fn foo() -> i32 {
605     let my_var = 5;
606     let res = if my_var == 5 {
607         42i32
608     } else {
609         return 24i32;
610     };
611     res
612 }
613 "#,
614         );
615     }
616
617     #[test]
618     fn unwrap_result_return_type_simple_with_tail_block_like_match_deeper() {
619         check_assist(
620             unwrap_result_return_type,
621             r#"
622 //- minicore: result
623 fn foo() -> Result<i32$0> {
624     let my_var = 5;
625     match my_var {
626         5 => {
627             if true {
628                 Ok(42i32)
629             } else {
630                 Ok(25i32)
631             }
632         },
633         _ => {
634             let test = "test";
635             if test == "test" {
636                 return Ok(bar());
637             }
638             Ok(53i32)
639         },
640     }
641 }
642 "#,
643             r#"
644 fn foo() -> i32 {
645     let my_var = 5;
646     match my_var {
647         5 => {
648             if true {
649                 42i32
650             } else {
651                 25i32
652             }
653         },
654         _ => {
655             let test = "test";
656             if test == "test" {
657                 return bar();
658             }
659             53i32
660         },
661     }
662 }
663 "#,
664         );
665     }
666
667     #[test]
668     fn unwrap_result_return_type_simple_with_tail_block_like_early_return() {
669         check_assist(
670             unwrap_result_return_type,
671             r#"
672 //- minicore: result
673 fn foo() -> Result<i32$0> {
674     let test = "test";
675     if test == "test" {
676         return Ok(24i32);
677     }
678     Ok(53i32)
679 }
680 "#,
681             r#"
682 fn foo() -> i32 {
683     let test = "test";
684     if test == "test" {
685         return 24i32;
686     }
687     53i32
688 }
689 "#,
690         );
691     }
692
693     #[test]
694     fn unwrap_result_return_type_simple_with_closure() {
695         check_assist(
696             unwrap_result_return_type,
697             r#"
698 //- minicore: result
699 fn foo(the_field: u32) -> Result<u32$0> {
700     let true_closure = || { return true; };
701     if the_field < 5 {
702         let mut i = 0;
703         if true_closure() {
704             return Ok(99);
705         } else {
706             return Ok(0);
707         }
708     }
709     Ok(the_field)
710 }
711 "#,
712             r#"
713 fn foo(the_field: u32) -> u32 {
714     let true_closure = || { return true; };
715     if the_field < 5 {
716         let mut i = 0;
717         if true_closure() {
718             return 99;
719         } else {
720             return 0;
721         }
722     }
723     the_field
724 }
725 "#,
726         );
727
728         check_assist(
729             unwrap_result_return_type,
730             r#"
731 //- minicore: result
732 fn foo(the_field: u32) -> Result<u32$0> {
733     let true_closure = || {
734         return true;
735     };
736     if the_field < 5 {
737         let mut i = 0;
738
739
740         if true_closure() {
741             return Ok(99);
742         } else {
743             return Ok(0);
744         }
745     }
746     let t = None;
747
748     Ok(t.unwrap_or_else(|| the_field))
749 }
750 "#,
751             r#"
752 fn foo(the_field: u32) -> u32 {
753     let true_closure = || {
754         return true;
755     };
756     if the_field < 5 {
757         let mut i = 0;
758
759
760         if true_closure() {
761             return 99;
762         } else {
763             return 0;
764         }
765     }
766     let t = None;
767
768     t.unwrap_or_else(|| the_field)
769 }
770 "#,
771         );
772     }
773
774     #[test]
775     fn unwrap_result_return_type_simple_with_weird_forms() {
776         check_assist(
777             unwrap_result_return_type,
778             r#"
779 //- minicore: result
780 fn foo() -> Result<i32$0> {
781     let test = "test";
782     if test == "test" {
783         return Ok(24i32);
784     }
785     let mut i = 0;
786     loop {
787         if i == 1 {
788             break Ok(55);
789         }
790         i += 1;
791     }
792 }
793 "#,
794             r#"
795 fn foo() -> i32 {
796     let test = "test";
797     if test == "test" {
798         return 24i32;
799     }
800     let mut i = 0;
801     loop {
802         if i == 1 {
803             break 55;
804         }
805         i += 1;
806     }
807 }
808 "#,
809         );
810
811         check_assist(
812             unwrap_result_return_type,
813             r#"
814 //- minicore: result
815 fn foo(the_field: u32) -> Result<u32$0> {
816     if the_field < 5 {
817         let mut i = 0;
818         loop {
819             if i > 5 {
820                 return Ok(55u32);
821             }
822             i += 3;
823         }
824         match i {
825             5 => return Ok(99),
826             _ => return Ok(0),
827         };
828     }
829     Ok(the_field)
830 }
831 "#,
832             r#"
833 fn foo(the_field: u32) -> u32 {
834     if the_field < 5 {
835         let mut i = 0;
836         loop {
837             if i > 5 {
838                 return 55u32;
839             }
840             i += 3;
841         }
842         match i {
843             5 => return 99,
844             _ => return 0,
845         };
846     }
847     the_field
848 }
849 "#,
850         );
851
852         check_assist(
853             unwrap_result_return_type,
854             r#"
855 //- minicore: result
856 fn foo(the_field: u32) -> Result<u32$0> {
857     if the_field < 5 {
858         let mut i = 0;
859         match i {
860             5 => return Ok(99),
861             _ => return Ok(0),
862         }
863     }
864     Ok(the_field)
865 }
866 "#,
867             r#"
868 fn foo(the_field: u32) -> u32 {
869     if the_field < 5 {
870         let mut i = 0;
871         match i {
872             5 => return 99,
873             _ => return 0,
874         }
875     }
876     the_field
877 }
878 "#,
879         );
880
881         check_assist(
882             unwrap_result_return_type,
883             r#"
884 //- minicore: result
885 fn foo(the_field: u32) -> Result<u32$0> {
886     if the_field < 5 {
887         let mut i = 0;
888         if i == 5 {
889             return Ok(99)
890         } else {
891             return Ok(0)
892         }
893     }
894     Ok(the_field)
895 }
896 "#,
897             r#"
898 fn foo(the_field: u32) -> u32 {
899     if the_field < 5 {
900         let mut i = 0;
901         if i == 5 {
902             return 99
903         } else {
904             return 0
905         }
906     }
907     the_field
908 }
909 "#,
910         );
911
912         check_assist(
913             unwrap_result_return_type,
914             r#"
915 //- minicore: result
916 fn foo(the_field: u32) -> Result<u3$02> {
917     if the_field < 5 {
918         let mut i = 0;
919         if i == 5 {
920             return Ok(99);
921         } else {
922             return Ok(0);
923         }
924     }
925     Ok(the_field)
926 }
927 "#,
928             r#"
929 fn foo(the_field: u32) -> u32 {
930     if the_field < 5 {
931         let mut i = 0;
932         if i == 5 {
933             return 99;
934         } else {
935             return 0;
936         }
937     }
938     the_field
939 }
940 "#,
941         );
942     }
943 }