]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/unwrap_result_return_type.rs
Merge #9939
[rust.git] / crates / ide_assists / src / handlers / unwrap_result_return_type.rs
1 use ide_db::helpers::{for_each_tail_expr, node_ext::walk_expr, FamousDefs};
2 use itertools::Itertools;
3 use syntax::{
4     ast::{self, Expr},
5     match_ast, AstNode, TextRange, TextSize,
6 };
7
8 use crate::{AssistContext, AssistId, AssistKind, Assists};
9
10 // Assist: unwrap_result_return_type
11 //
12 // Unwrap the function's return type.
13 //
14 // ```
15 // # //- minicore: result
16 // fn foo() -> Result<i32>$0 { Ok(42i32) }
17 // ```
18 // ->
19 // ```
20 // fn foo() -> i32 { 42i32 }
21 // ```
22 pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
23     let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
24     let parent = ret_type.syntax().parent()?;
25     let body = match_ast! {
26         match parent {
27             ast::Fn(func) => func.body()?,
28             ast::ClosureExpr(closure) => match closure.body()? {
29                 Expr::BlockExpr(block) => block,
30                 // closures require a block when a return type is specified
31                 _ => return None,
32             },
33             _ => return None,
34         }
35     };
36
37     let type_ref = &ret_type.ty()?;
38     let ty = ctx.sema.resolve_type(type_ref).and_then(|ty| ty.as_adt());
39     let result_enum =
40         FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax()).krate()).core_result_Result()?;
41
42     if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
43         return None;
44     }
45
46     acc.add(
47         AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite),
48         "Unwrap Result return type",
49         type_ref.syntax().text_range(),
50         |builder| {
51             let body = ast::Expr::BlockExpr(body);
52
53             let mut exprs_to_unwrap = Vec::new();
54             let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_unwrap, e);
55             walk_expr(&body, &mut |expr| {
56                 if let Expr::ReturnExpr(ret_expr) = expr {
57                     if let Some(ret_expr_arg) = &ret_expr.expr() {
58                         for_each_tail_expr(ret_expr_arg, tail_cb);
59                     }
60                 }
61             });
62             for_each_tail_expr(&body, tail_cb);
63
64             let mut is_unit_type = false;
65             if let Some((_, inner_type)) = type_ref.to_string().split_once('<') {
66                 let inner_type = match inner_type.split_once(',') {
67                     Some((success_inner_type, _)) => success_inner_type,
68                     None => inner_type,
69                 };
70                 let new_ret_type = inner_type.strip_suffix('>').unwrap_or(inner_type);
71                 if new_ret_type == "()" {
72                     is_unit_type = true;
73                     let text_range = TextRange::new(
74                         ret_type.syntax().text_range().start(),
75                         ret_type.syntax().text_range().end() + TextSize::from(1u32),
76                     );
77                     builder.delete(text_range)
78                 } else {
79                     builder.replace(
80                         type_ref.syntax().text_range(),
81                         inner_type.strip_suffix('>').unwrap_or(inner_type),
82                     )
83                 }
84             }
85
86             for ret_expr_arg in exprs_to_unwrap {
87                 let ret_expr_str = ret_expr_arg.to_string();
88                 if ret_expr_str.starts_with("Ok(") || ret_expr_str.starts_with("Err(") {
89                     let arg_list = ret_expr_arg.syntax().children().find_map(ast::ArgList::cast);
90                     if let Some(arg_list) = arg_list {
91                         if is_unit_type {
92                             match ret_expr_arg.syntax().prev_sibling_or_token() {
93                                 // Useful to delete the entire line without leaving trailing whitespaces
94                                 Some(whitespace) => {
95                                     let new_range = TextRange::new(
96                                         whitespace.text_range().start(),
97                                         ret_expr_arg.syntax().text_range().end(),
98                                     );
99                                     builder.delete(new_range);
100                                 }
101                                 None => {
102                                     builder.delete(ret_expr_arg.syntax().text_range());
103                                 }
104                             }
105                         } else {
106                             builder.replace(
107                                 ret_expr_arg.syntax().text_range(),
108                                 arg_list.args().join(", "),
109                             );
110                         }
111                     }
112                 }
113             }
114         },
115     )
116 }
117
118 fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
119     match e {
120         Expr::BreakExpr(break_expr) => {
121             if let Some(break_expr_arg) = break_expr.expr() {
122                 for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
123             }
124         }
125         Expr::ReturnExpr(ret_expr) => {
126             if let Some(ret_expr_arg) = &ret_expr.expr() {
127                 for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e));
128             }
129         }
130         e => acc.push(e.clone()),
131     }
132 }
133
134 #[cfg(test)]
135 mod tests {
136     use crate::tests::{check_assist, check_assist_not_applicable};
137
138     use super::*;
139
140     #[test]
141     fn unwrap_result_return_type_simple() {
142         check_assist(
143             unwrap_result_return_type,
144             r#"
145 //- minicore: result
146 fn foo() -> Result<i3$02> {
147     let test = "test";
148     return Ok(42i32);
149 }
150 "#,
151             r#"
152 fn foo() -> i32 {
153     let test = "test";
154     return 42i32;
155 }
156 "#,
157         );
158     }
159
160     #[test]
161     fn unwrap_result_return_type_unit_type() {
162         check_assist(
163             unwrap_result_return_type,
164             r#"
165 //- minicore: result
166 fn foo() -> Result<(), Box<dyn Error$0>> {
167     Ok(())
168 }
169 "#,
170             r#"
171 fn foo() {
172 }
173 "#,
174         );
175     }
176
177     #[test]
178     fn unwrap_result_return_type_ending_with_parent() {
179         check_assist(
180             unwrap_result_return_type,
181             r#"
182 //- minicore: result
183 fn foo() -> Result<i32, Box<dyn Error$0>> {
184     if true {
185         Ok(42)
186     } else {
187         foo()
188     }
189 }
190 "#,
191             r#"
192 fn foo() -> i32 {
193     if true {
194         42
195     } else {
196         foo()
197     }
198 }
199 "#,
200         );
201     }
202
203     #[test]
204     fn unwrap_return_type_break_split_tail() {
205         check_assist(
206             unwrap_result_return_type,
207             r#"
208 //- minicore: result
209 fn foo() -> Result<i3$02, String> {
210     loop {
211         break if true {
212             Ok(1)
213         } else {
214             Ok(0)
215         };
216     }
217 }
218 "#,
219             r#"
220 fn foo() -> i32 {
221     loop {
222         break if true {
223             1
224         } else {
225             0
226         };
227     }
228 }
229 "#,
230         );
231     }
232
233     #[test]
234     fn unwrap_result_return_type_simple_closure() {
235         check_assist(
236             unwrap_result_return_type,
237             r#"
238 //- minicore: result
239 fn foo() {
240     || -> Result<i32$0> {
241         let test = "test";
242         return Ok(42i32);
243     };
244 }
245 "#,
246             r#"
247 fn foo() {
248     || -> i32 {
249         let test = "test";
250         return 42i32;
251     };
252 }
253 "#,
254         );
255     }
256
257     #[test]
258     fn unwrap_result_return_type_simple_return_type_bad_cursor() {
259         check_assist_not_applicable(
260             unwrap_result_return_type,
261             r#"
262 //- minicore: result
263 fn foo() -> i32 {
264     let test = "test";$0
265     return 42i32;
266 }
267 "#,
268         );
269     }
270
271     #[test]
272     fn unwrap_result_return_type_simple_return_type_bad_cursor_closure() {
273         check_assist_not_applicable(
274             unwrap_result_return_type,
275             r#"
276 //- minicore: result
277 fn foo() {
278     || -> i32 {
279         let test = "test";$0
280         return 42i32;
281     };
282 }
283 "#,
284         );
285     }
286
287     #[test]
288     fn unwrap_result_return_type_closure_non_block() {
289         check_assist_not_applicable(
290             unwrap_result_return_type,
291             r#"
292 //- minicore: result
293 fn foo() { || -> i$032 3; }
294 "#,
295         );
296     }
297
298     #[test]
299     fn unwrap_result_return_type_simple_return_type_already_not_result_std() {
300         check_assist_not_applicable(
301             unwrap_result_return_type,
302             r#"
303 //- minicore: result
304 fn foo() -> i32$0 {
305     let test = "test";
306     return 42i32;
307 }
308 "#,
309         );
310     }
311
312     #[test]
313     fn unwrap_result_return_type_simple_return_type_already_not_result_closure() {
314         check_assist_not_applicable(
315             unwrap_result_return_type,
316             r#"
317 //- minicore: result
318 fn foo() {
319     || -> i32$0 {
320         let test = "test";
321         return 42i32;
322     };
323 }
324 "#,
325         );
326     }
327
328     #[test]
329     fn unwrap_result_return_type_simple_with_tail() {
330         check_assist(
331             unwrap_result_return_type,
332             r#"
333 //- minicore: result
334 fn foo() ->$0 Result<i32> {
335     let test = "test";
336     Ok(42i32)
337 }
338 "#,
339             r#"
340 fn foo() -> i32 {
341     let test = "test";
342     42i32
343 }
344 "#,
345         );
346     }
347
348     #[test]
349     fn unwrap_result_return_type_simple_with_tail_closure() {
350         check_assist(
351             unwrap_result_return_type,
352             r#"
353 //- minicore: result
354 fn foo() {
355     || ->$0 Result<i32, String> {
356         let test = "test";
357         Ok(42i32)
358     };
359 }
360 "#,
361             r#"
362 fn foo() {
363     || -> i32 {
364         let test = "test";
365         42i32
366     };
367 }
368 "#,
369         );
370     }
371
372     #[test]
373     fn unwrap_result_return_type_simple_with_tail_only() {
374         check_assist(
375             unwrap_result_return_type,
376             r#"
377 //- minicore: result
378 fn foo() -> Result<i32$0> { Ok(42i32) }
379 "#,
380             r#"
381 fn foo() -> i32 { 42i32 }
382 "#,
383         );
384     }
385
386     #[test]
387     fn unwrap_result_return_type_simple_with_tail_block_like() {
388         check_assist(
389             unwrap_result_return_type,
390             r#"
391 //- minicore: result
392 fn foo() -> Result<i32>$0 {
393     if true {
394         Ok(42i32)
395     } else {
396         Ok(24i32)
397     }
398 }
399 "#,
400             r#"
401 fn foo() -> i32 {
402     if true {
403         42i32
404     } else {
405         24i32
406     }
407 }
408 "#,
409         );
410     }
411
412     #[test]
413     fn unwrap_result_return_type_simple_without_block_closure() {
414         check_assist(
415             unwrap_result_return_type,
416             r#"
417 //- minicore: result
418 fn foo() {
419     || -> Result<i32, String>$0 {
420         if true {
421             Ok(42i32)
422         } else {
423             Ok(24i32)
424         }
425     };
426 }
427 "#,
428             r#"
429 fn foo() {
430     || -> i32 {
431         if true {
432             42i32
433         } else {
434             24i32
435         }
436     };
437 }
438 "#,
439         );
440     }
441
442     #[test]
443     fn unwrap_result_return_type_simple_with_nested_if() {
444         check_assist(
445             unwrap_result_return_type,
446             r#"
447 //- minicore: result
448 fn foo() -> Result<i32>$0 {
449     if true {
450         if false {
451             Ok(1)
452         } else {
453             Ok(2)
454         }
455     } else {
456         Ok(24i32)
457     }
458 }
459 "#,
460             r#"
461 fn foo() -> i32 {
462     if true {
463         if false {
464             1
465         } else {
466             2
467         }
468     } else {
469         24i32
470     }
471 }
472 "#,
473         );
474     }
475
476     #[test]
477     fn unwrap_result_return_type_simple_with_await() {
478         check_assist(
479             unwrap_result_return_type,
480             r#"
481 //- minicore: result
482 async fn foo() -> Result<i$032> {
483     if true {
484         if false {
485             Ok(1.await)
486         } else {
487             Ok(2.await)
488         }
489     } else {
490         Ok(24i32.await)
491     }
492 }
493 "#,
494             r#"
495 async fn foo() -> i32 {
496     if true {
497         if false {
498             1.await
499         } else {
500             2.await
501         }
502     } else {
503         24i32.await
504     }
505 }
506 "#,
507         );
508     }
509
510     #[test]
511     fn unwrap_result_return_type_simple_with_array() {
512         check_assist(
513             unwrap_result_return_type,
514             r#"
515 //- minicore: result
516 fn foo() -> Result<[i32; 3]$0> { Ok([1, 2, 3]) }
517 "#,
518             r#"
519 fn foo() -> [i32; 3] { [1, 2, 3] }
520 "#,
521         );
522     }
523
524     #[test]
525     fn unwrap_result_return_type_simple_with_cast() {
526         check_assist(
527             unwrap_result_return_type,
528             r#"
529 //- minicore: result
530 fn foo() -$0> Result<i32> {
531     if true {
532         if false {
533             Ok(1 as i32)
534         } else {
535             Ok(2 as i32)
536         }
537     } else {
538         Ok(24 as i32)
539     }
540 }
541 "#,
542             r#"
543 fn foo() -> i32 {
544     if true {
545         if false {
546             1 as i32
547         } else {
548             2 as i32
549         }
550     } else {
551         24 as i32
552     }
553 }
554 "#,
555         );
556     }
557
558     #[test]
559     fn unwrap_result_return_type_simple_with_tail_block_like_match() {
560         check_assist(
561             unwrap_result_return_type,
562             r#"
563 //- minicore: result
564 fn foo() -> Result<i32$0> {
565     let my_var = 5;
566     match my_var {
567         5 => Ok(42i32),
568         _ => Ok(24i32),
569     }
570 }
571 "#,
572             r#"
573 fn foo() -> i32 {
574     let my_var = 5;
575     match my_var {
576         5 => 42i32,
577         _ => 24i32,
578     }
579 }
580 "#,
581         );
582     }
583
584     #[test]
585     fn unwrap_result_return_type_simple_with_loop_with_tail() {
586         check_assist(
587             unwrap_result_return_type,
588             r#"
589 //- minicore: result
590 fn foo() -> Result<i32$0> {
591     let my_var = 5;
592     loop {
593         println!("test");
594         5
595     }
596     Ok(my_var)
597 }
598 "#,
599             r#"
600 fn foo() -> i32 {
601     let my_var = 5;
602     loop {
603         println!("test");
604         5
605     }
606     my_var
607 }
608 "#,
609         );
610     }
611
612     #[test]
613     fn unwrap_result_return_type_simple_with_loop_in_let_stmt() {
614         check_assist(
615             unwrap_result_return_type,
616             r#"
617 //- minicore: result
618 fn foo() -> Result<i32$0> {
619     let my_var = let x = loop {
620         break 1;
621     };
622     Ok(my_var)
623 }
624 "#,
625             r#"
626 fn foo() -> i32 {
627     let my_var = let x = loop {
628         break 1;
629     };
630     my_var
631 }
632 "#,
633         );
634     }
635
636     #[test]
637     fn unwrap_result_return_type_simple_with_tail_block_like_match_return_expr() {
638         check_assist(
639             unwrap_result_return_type,
640             r#"
641 //- minicore: result
642 fn foo() -> Result<i32>$0 {
643     let my_var = 5;
644     let res = match my_var {
645         5 => 42i32,
646         _ => return Ok(24i32),
647     };
648     Ok(res)
649 }
650 "#,
651             r#"
652 fn foo() -> i32 {
653     let my_var = 5;
654     let res = match my_var {
655         5 => 42i32,
656         _ => return 24i32,
657     };
658     res
659 }
660 "#,
661         );
662
663         check_assist(
664             unwrap_result_return_type,
665             r#"
666 //- minicore: result
667 fn foo() -> Result<i32$0> {
668     let my_var = 5;
669     let res = if my_var == 5 {
670         42i32
671     } else {
672         return Ok(24i32);
673     };
674     Ok(res)
675 }
676 "#,
677             r#"
678 fn foo() -> i32 {
679     let my_var = 5;
680     let res = if my_var == 5 {
681         42i32
682     } else {
683         return 24i32;
684     };
685     res
686 }
687 "#,
688         );
689     }
690
691     #[test]
692     fn unwrap_result_return_type_simple_with_tail_block_like_match_deeper() {
693         check_assist(
694             unwrap_result_return_type,
695             r#"
696 //- minicore: result
697 fn foo() -> Result<i32$0> {
698     let my_var = 5;
699     match my_var {
700         5 => {
701             if true {
702                 Ok(42i32)
703             } else {
704                 Ok(25i32)
705             }
706         },
707         _ => {
708             let test = "test";
709             if test == "test" {
710                 return Ok(bar());
711             }
712             Ok(53i32)
713         },
714     }
715 }
716 "#,
717             r#"
718 fn foo() -> i32 {
719     let my_var = 5;
720     match my_var {
721         5 => {
722             if true {
723                 42i32
724             } else {
725                 25i32
726             }
727         },
728         _ => {
729             let test = "test";
730             if test == "test" {
731                 return bar();
732             }
733             53i32
734         },
735     }
736 }
737 "#,
738         );
739     }
740
741     #[test]
742     fn unwrap_result_return_type_simple_with_tail_block_like_early_return() {
743         check_assist(
744             unwrap_result_return_type,
745             r#"
746 //- minicore: result
747 fn foo() -> Result<i32$0> {
748     let test = "test";
749     if test == "test" {
750         return Ok(24i32);
751     }
752     Ok(53i32)
753 }
754 "#,
755             r#"
756 fn foo() -> i32 {
757     let test = "test";
758     if test == "test" {
759         return 24i32;
760     }
761     53i32
762 }
763 "#,
764         );
765     }
766
767     #[test]
768     fn unwrap_result_return_type_simple_with_closure() {
769         check_assist(
770             unwrap_result_return_type,
771             r#"
772 //- minicore: result
773 fn foo(the_field: u32) -> Result<u32$0> {
774     let true_closure = || { return true; };
775     if the_field < 5 {
776         let mut i = 0;
777         if true_closure() {
778             return Ok(99);
779         } else {
780             return Ok(0);
781         }
782     }
783     Ok(the_field)
784 }
785 "#,
786             r#"
787 fn foo(the_field: u32) -> u32 {
788     let true_closure = || { return true; };
789     if the_field < 5 {
790         let mut i = 0;
791         if true_closure() {
792             return 99;
793         } else {
794             return 0;
795         }
796     }
797     the_field
798 }
799 "#,
800         );
801
802         check_assist(
803             unwrap_result_return_type,
804             r#"
805 //- minicore: result
806 fn foo(the_field: u32) -> Result<u32$0> {
807     let true_closure = || {
808         return true;
809     };
810     if the_field < 5 {
811         let mut i = 0;
812
813
814         if true_closure() {
815             return Ok(99);
816         } else {
817             return Ok(0);
818         }
819     }
820     let t = None;
821
822     Ok(t.unwrap_or_else(|| the_field))
823 }
824 "#,
825             r#"
826 fn foo(the_field: u32) -> u32 {
827     let true_closure = || {
828         return true;
829     };
830     if the_field < 5 {
831         let mut i = 0;
832
833
834         if true_closure() {
835             return 99;
836         } else {
837             return 0;
838         }
839     }
840     let t = None;
841
842     t.unwrap_or_else(|| the_field)
843 }
844 "#,
845         );
846     }
847
848     #[test]
849     fn unwrap_result_return_type_simple_with_weird_forms() {
850         check_assist(
851             unwrap_result_return_type,
852             r#"
853 //- minicore: result
854 fn foo() -> Result<i32$0> {
855     let test = "test";
856     if test == "test" {
857         return Ok(24i32);
858     }
859     let mut i = 0;
860     loop {
861         if i == 1 {
862             break Ok(55);
863         }
864         i += 1;
865     }
866 }
867 "#,
868             r#"
869 fn foo() -> i32 {
870     let test = "test";
871     if test == "test" {
872         return 24i32;
873     }
874     let mut i = 0;
875     loop {
876         if i == 1 {
877             break 55;
878         }
879         i += 1;
880     }
881 }
882 "#,
883         );
884
885         check_assist(
886             unwrap_result_return_type,
887             r#"
888 //- minicore: result
889 fn foo(the_field: u32) -> Result<u32$0> {
890     if the_field < 5 {
891         let mut i = 0;
892         loop {
893             if i > 5 {
894                 return Ok(55u32);
895             }
896             i += 3;
897         }
898         match i {
899             5 => return Ok(99),
900             _ => return Ok(0),
901         };
902     }
903     Ok(the_field)
904 }
905 "#,
906             r#"
907 fn foo(the_field: u32) -> u32 {
908     if the_field < 5 {
909         let mut i = 0;
910         loop {
911             if i > 5 {
912                 return 55u32;
913             }
914             i += 3;
915         }
916         match i {
917             5 => return 99,
918             _ => return 0,
919         };
920     }
921     the_field
922 }
923 "#,
924         );
925
926         check_assist(
927             unwrap_result_return_type,
928             r#"
929 //- minicore: result
930 fn foo(the_field: u32) -> Result<u32$0> {
931     if the_field < 5 {
932         let mut i = 0;
933         match i {
934             5 => return Ok(99),
935             _ => return Ok(0),
936         }
937     }
938     Ok(the_field)
939 }
940 "#,
941             r#"
942 fn foo(the_field: u32) -> u32 {
943     if the_field < 5 {
944         let mut i = 0;
945         match i {
946             5 => return 99,
947             _ => return 0,
948         }
949     }
950     the_field
951 }
952 "#,
953         );
954
955         check_assist(
956             unwrap_result_return_type,
957             r#"
958 //- minicore: result
959 fn foo(the_field: u32) -> Result<u32$0> {
960     if the_field < 5 {
961         let mut i = 0;
962         if i == 5 {
963             return Ok(99)
964         } else {
965             return Ok(0)
966         }
967     }
968     Ok(the_field)
969 }
970 "#,
971             r#"
972 fn foo(the_field: u32) -> u32 {
973     if the_field < 5 {
974         let mut i = 0;
975         if i == 5 {
976             return 99
977         } else {
978             return 0
979         }
980     }
981     the_field
982 }
983 "#,
984         );
985
986         check_assist(
987             unwrap_result_return_type,
988             r#"
989 //- minicore: result
990 fn foo(the_field: u32) -> Result<u3$02> {
991     if the_field < 5 {
992         let mut i = 0;
993         if i == 5 {
994             return Ok(99);
995         } else {
996             return Ok(0);
997         }
998     }
999     Ok(the_field)
1000 }
1001 "#,
1002             r#"
1003 fn foo(the_field: u32) -> u32 {
1004     if the_field < 5 {
1005         let mut i = 0;
1006         if i == 5 {
1007             return 99;
1008         } else {
1009             return 0;
1010         }
1011     }
1012     the_field
1013 }
1014 "#,
1015         );
1016     }
1017 }