]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/wrap_return_type_in_result.rs
Deduplicate ast expression walking logic
[rust.git] / crates / ide_assists / src / handlers / wrap_return_type_in_result.rs
1 use std::iter;
2
3 use ide_db::helpers::for_each_tail_expr;
4 use syntax::{
5     ast::{self, make, Expr},
6     match_ast, AstNode,
7 };
8
9 use crate::{AssistContext, AssistId, AssistKind, Assists};
10
11 // Assist: wrap_return_type_in_result
12 //
13 // Wrap the function's return type into Result.
14 //
15 // ```
16 // fn foo() -> i32$0 { 42i32 }
17 // ```
18 // ->
19 // ```
20 // fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
21 // ```
22 pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
23     let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
24     let parent = ret_type.syntax().parent()?;
25     let body = match_ast! {
26         match parent {
27             ast::Fn(func) => func.body()?,
28             ast::ClosureExpr(closure) => match closure.body()? {
29                 Expr::BlockExpr(block) => block,
30                 // closures require a block when a return type is specified
31                 _ => return None,
32             },
33             _ => return None,
34         }
35     };
36     let body = ast::Expr::BlockExpr(body);
37
38     let type_ref = &ret_type.ty()?;
39     let ret_type_str = type_ref.syntax().text().to_string();
40     let first_part_ret_type = ret_type_str.splitn(2, '<').next();
41     if let Some(ret_type_first_part) = first_part_ret_type {
42         if ret_type_first_part.ends_with("Result") {
43             cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
44             return None;
45         }
46     }
47
48     acc.add(
49         AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite),
50         "Wrap return type in Result",
51         type_ref.syntax().text_range(),
52         |builder| {
53             let mut exprs_to_wrap = Vec::new();
54             let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e);
55             body.walk(&mut |expr| {
56                 if let Expr::ReturnExpr(ret_expr) = expr {
57                     if let Some(ret_expr_arg) = &ret_expr.expr() {
58                         for_each_tail_expr(ret_expr_arg, tail_cb);
59                     }
60                 }
61             });
62             for_each_tail_expr(&body, tail_cb);
63
64             for ret_expr_arg in exprs_to_wrap {
65                 let ok_wrapped = make::expr_call(
66                     make::expr_path(make::ext::ident_path("Ok")),
67                     make::arg_list(iter::once(ret_expr_arg.clone())),
68                 );
69                 builder.replace_ast(ret_expr_arg, ok_wrapped);
70             }
71
72             match ctx.config.snippet_cap {
73                 Some(cap) => {
74                     let snippet = format!("Result<{}, ${{0:_}}>", type_ref);
75                     builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet)
76                 }
77                 None => builder
78                     .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)),
79             }
80         },
81     )
82 }
83
84 fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
85     match e {
86         Expr::BreakExpr(break_expr) => {
87             if let Some(break_expr_arg) = break_expr.expr() {
88                 for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
89             }
90         }
91         e => acc.push(e.clone()),
92     }
93 }
94
95 #[cfg(test)]
96 mod tests {
97     use crate::tests::{check_assist, check_assist_not_applicable};
98
99     use super::*;
100
101     #[test]
102     fn wrap_return_type_in_result_simple() {
103         check_assist(
104             wrap_return_type_in_result,
105             r#"
106 fn foo() -> i3$02 {
107     let test = "test";
108     return 42i32;
109 }
110 "#,
111             r#"
112 fn foo() -> Result<i32, ${0:_}> {
113     let test = "test";
114     return Ok(42i32);
115 }
116 "#,
117         );
118     }
119
120     #[test]
121     fn wrap_return_type_break_split_tail() {
122         check_assist(
123             wrap_return_type_in_result,
124             r#"
125 fn foo() -> i3$02 {
126     loop {
127         break if true {
128             1
129         } else {
130             0
131         };
132     }
133 }
134 "#,
135             r#"
136 fn foo() -> Result<i32, ${0:_}> {
137     loop {
138         break if true {
139             Ok(1)
140         } else {
141             Ok(0)
142         };
143     }
144 }
145 "#,
146         );
147     }
148
149     #[test]
150     fn wrap_return_type_in_result_simple_closure() {
151         check_assist(
152             wrap_return_type_in_result,
153             r#"
154 fn foo() {
155     || -> i32$0 {
156         let test = "test";
157         return 42i32;
158     };
159 }
160 "#,
161             r#"
162 fn foo() {
163     || -> Result<i32, ${0:_}> {
164         let test = "test";
165         return Ok(42i32);
166     };
167 }
168 "#,
169         );
170     }
171
172     #[test]
173     fn wrap_return_type_in_result_simple_return_type_bad_cursor() {
174         check_assist_not_applicable(
175             wrap_return_type_in_result,
176             r#"
177 fn foo() -> i32 {
178     let test = "test";$0
179     return 42i32;
180 }
181 "#,
182         );
183     }
184
185     #[test]
186     fn wrap_return_type_in_result_simple_return_type_bad_cursor_closure() {
187         check_assist_not_applicable(
188             wrap_return_type_in_result,
189             r#"
190 fn foo() {
191     || -> i32 {
192         let test = "test";$0
193         return 42i32;
194     };
195 }
196 "#,
197         );
198     }
199
200     #[test]
201     fn wrap_return_type_in_result_closure_non_block() {
202         check_assist_not_applicable(wrap_return_type_in_result, r#"fn foo() { || -> i$032 3; }"#);
203     }
204
205     #[test]
206     fn wrap_return_type_in_result_simple_return_type_already_result_std() {
207         check_assist_not_applicable(
208             wrap_return_type_in_result,
209             r#"
210 fn foo() -> std::result::Result<i32$0, String> {
211     let test = "test";
212     return 42i32;
213 }
214 "#,
215         );
216     }
217
218     #[test]
219     fn wrap_return_type_in_result_simple_return_type_already_result() {
220         cov_mark::check!(wrap_return_type_in_result_simple_return_type_already_result);
221         check_assist_not_applicable(
222             wrap_return_type_in_result,
223             r#"
224 fn foo() -> Result<i32$0, String> {
225     let test = "test";
226     return 42i32;
227 }
228 "#,
229         );
230     }
231
232     #[test]
233     fn wrap_return_type_in_result_simple_return_type_already_result_closure() {
234         check_assist_not_applicable(
235             wrap_return_type_in_result,
236             r#"
237 fn foo() {
238     || -> Result<i32$0, String> {
239         let test = "test";
240         return 42i32;
241     };
242 }
243 "#,
244         );
245     }
246
247     #[test]
248     fn wrap_return_type_in_result_simple_with_cursor() {
249         check_assist(
250             wrap_return_type_in_result,
251             r#"
252 fn foo() -> $0i32 {
253     let test = "test";
254     return 42i32;
255 }
256 "#,
257             r#"
258 fn foo() -> Result<i32, ${0:_}> {
259     let test = "test";
260     return Ok(42i32);
261 }
262 "#,
263         );
264     }
265
266     #[test]
267     fn wrap_return_type_in_result_simple_with_tail() {
268         check_assist(
269             wrap_return_type_in_result,
270             r#"
271 fn foo() ->$0 i32 {
272     let test = "test";
273     42i32
274 }
275 "#,
276             r#"
277 fn foo() -> Result<i32, ${0:_}> {
278     let test = "test";
279     Ok(42i32)
280 }
281 "#,
282         );
283     }
284
285     #[test]
286     fn wrap_return_type_in_result_simple_with_tail_closure() {
287         check_assist(
288             wrap_return_type_in_result,
289             r#"
290 fn foo() {
291     || ->$0 i32 {
292         let test = "test";
293         42i32
294     };
295 }
296 "#,
297             r#"
298 fn foo() {
299     || -> Result<i32, ${0:_}> {
300         let test = "test";
301         Ok(42i32)
302     };
303 }
304 "#,
305         );
306     }
307
308     #[test]
309     fn wrap_return_type_in_result_simple_with_tail_only() {
310         check_assist(
311             wrap_return_type_in_result,
312             r#"fn foo() -> i32$0 { 42i32 }"#,
313             r#"fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }"#,
314         );
315     }
316
317     #[test]
318     fn wrap_return_type_in_result_simple_with_tail_block_like() {
319         check_assist(
320             wrap_return_type_in_result,
321             r#"
322 fn foo() -> i32$0 {
323     if true {
324         42i32
325     } else {
326         24i32
327     }
328 }
329 "#,
330             r#"
331 fn foo() -> Result<i32, ${0:_}> {
332     if true {
333         Ok(42i32)
334     } else {
335         Ok(24i32)
336     }
337 }
338 "#,
339         );
340     }
341
342     #[test]
343     fn wrap_return_type_in_result_simple_without_block_closure() {
344         check_assist(
345             wrap_return_type_in_result,
346             r#"
347 fn foo() {
348     || -> i32$0 {
349         if true {
350             42i32
351         } else {
352             24i32
353         }
354     };
355 }
356 "#,
357             r#"
358 fn foo() {
359     || -> Result<i32, ${0:_}> {
360         if true {
361             Ok(42i32)
362         } else {
363             Ok(24i32)
364         }
365     };
366 }
367 "#,
368         );
369     }
370
371     #[test]
372     fn wrap_return_type_in_result_simple_with_nested_if() {
373         check_assist(
374             wrap_return_type_in_result,
375             r#"
376 fn foo() -> i32$0 {
377     if true {
378         if false {
379             1
380         } else {
381             2
382         }
383     } else {
384         24i32
385     }
386 }
387 "#,
388             r#"
389 fn foo() -> Result<i32, ${0:_}> {
390     if true {
391         if false {
392             Ok(1)
393         } else {
394             Ok(2)
395         }
396     } else {
397         Ok(24i32)
398     }
399 }
400 "#,
401         );
402     }
403
404     #[test]
405     fn wrap_return_type_in_result_simple_with_await() {
406         check_assist(
407             wrap_return_type_in_result,
408             r#"
409 async fn foo() -> i$032 {
410     if true {
411         if false {
412             1.await
413         } else {
414             2.await
415         }
416     } else {
417         24i32.await
418     }
419 }
420 "#,
421             r#"
422 async fn foo() -> Result<i32, ${0:_}> {
423     if true {
424         if false {
425             Ok(1.await)
426         } else {
427             Ok(2.await)
428         }
429     } else {
430         Ok(24i32.await)
431     }
432 }
433 "#,
434         );
435     }
436
437     #[test]
438     fn wrap_return_type_in_result_simple_with_array() {
439         check_assist(
440             wrap_return_type_in_result,
441             r#"fn foo() -> [i32;$0 3] { [1, 2, 3] }"#,
442             r#"fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }"#,
443         );
444     }
445
446     #[test]
447     fn wrap_return_type_in_result_simple_with_cast() {
448         check_assist(
449             wrap_return_type_in_result,
450             r#"
451 fn foo() -$0> i32 {
452     if true {
453         if false {
454             1 as i32
455         } else {
456             2 as i32
457         }
458     } else {
459         24 as i32
460     }
461 }
462 "#,
463             r#"
464 fn foo() -> Result<i32, ${0:_}> {
465     if true {
466         if false {
467             Ok(1 as i32)
468         } else {
469             Ok(2 as i32)
470         }
471     } else {
472         Ok(24 as i32)
473     }
474 }
475 "#,
476         );
477     }
478
479     #[test]
480     fn wrap_return_type_in_result_simple_with_tail_block_like_match() {
481         check_assist(
482             wrap_return_type_in_result,
483             r#"
484 fn foo() -> i32$0 {
485     let my_var = 5;
486     match my_var {
487         5 => 42i32,
488         _ => 24i32,
489     }
490 }
491 "#,
492             r#"
493 fn foo() -> Result<i32, ${0:_}> {
494     let my_var = 5;
495     match my_var {
496         5 => Ok(42i32),
497         _ => Ok(24i32),
498     }
499 }
500 "#,
501         );
502     }
503
504     #[test]
505     fn wrap_return_type_in_result_simple_with_loop_with_tail() {
506         check_assist(
507             wrap_return_type_in_result,
508             r#"
509 fn foo() -> i32$0 {
510     let my_var = 5;
511     loop {
512         println!("test");
513         5
514     }
515     my_var
516 }
517 "#,
518             r#"
519 fn foo() -> Result<i32, ${0:_}> {
520     let my_var = 5;
521     loop {
522         println!("test");
523         5
524     }
525     Ok(my_var)
526 }
527 "#,
528         );
529     }
530
531     #[test]
532     fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() {
533         check_assist(
534             wrap_return_type_in_result,
535             r#"
536 fn foo() -> i32$0 {
537     let my_var = let x = loop {
538         break 1;
539     };
540     my_var
541 }
542 "#,
543             r#"
544 fn foo() -> Result<i32, ${0:_}> {
545     let my_var = let x = loop {
546         break 1;
547     };
548     Ok(my_var)
549 }
550 "#,
551         );
552     }
553
554     #[test]
555     fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() {
556         check_assist(
557             wrap_return_type_in_result,
558             r#"
559 fn foo() -> i32$0 {
560     let my_var = 5;
561     let res = match my_var {
562         5 => 42i32,
563         _ => return 24i32,
564     };
565     res
566 }
567 "#,
568             r#"
569 fn foo() -> Result<i32, ${0:_}> {
570     let my_var = 5;
571     let res = match my_var {
572         5 => 42i32,
573         _ => return Ok(24i32),
574     };
575     Ok(res)
576 }
577 "#,
578         );
579
580         check_assist(
581             wrap_return_type_in_result,
582             r#"
583 fn foo() -> i32$0 {
584     let my_var = 5;
585     let res = if my_var == 5 {
586         42i32
587     } else {
588         return 24i32;
589     };
590     res
591 }
592 "#,
593             r#"
594 fn foo() -> Result<i32, ${0:_}> {
595     let my_var = 5;
596     let res = if my_var == 5 {
597         42i32
598     } else {
599         return Ok(24i32);
600     };
601     Ok(res)
602 }
603 "#,
604         );
605     }
606
607     #[test]
608     fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() {
609         check_assist(
610             wrap_return_type_in_result,
611             r#"
612 fn foo() -> i32$0 {
613     let my_var = 5;
614     match my_var {
615         5 => {
616             if true {
617                 42i32
618             } else {
619                 25i32
620             }
621         },
622         _ => {
623             let test = "test";
624             if test == "test" {
625                 return bar();
626             }
627             53i32
628         },
629     }
630 }
631 "#,
632             r#"
633 fn foo() -> Result<i32, ${0:_}> {
634     let my_var = 5;
635     match my_var {
636         5 => {
637             if true {
638                 Ok(42i32)
639             } else {
640                 Ok(25i32)
641             }
642         },
643         _ => {
644             let test = "test";
645             if test == "test" {
646                 return Ok(bar());
647             }
648             Ok(53i32)
649         },
650     }
651 }
652 "#,
653         );
654     }
655
656     #[test]
657     fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() {
658         check_assist(
659             wrap_return_type_in_result,
660             r#"
661 fn foo() -> i$032 {
662     let test = "test";
663     if test == "test" {
664         return 24i32;
665     }
666     53i32
667 }
668 "#,
669             r#"
670 fn foo() -> Result<i32, ${0:_}> {
671     let test = "test";
672     if test == "test" {
673         return Ok(24i32);
674     }
675     Ok(53i32)
676 }
677 "#,
678         );
679     }
680
681     #[test]
682     fn wrap_return_type_in_result_simple_with_closure() {
683         check_assist(
684             wrap_return_type_in_result,
685             r#"
686 fn foo(the_field: u32) ->$0 u32 {
687     let true_closure = || { return true; };
688     if the_field < 5 {
689         let mut i = 0;
690         if true_closure() {
691             return 99;
692         } else {
693             return 0;
694         }
695     }
696     the_field
697 }
698 "#,
699             r#"
700 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
701     let true_closure = || { return true; };
702     if the_field < 5 {
703         let mut i = 0;
704         if true_closure() {
705             return Ok(99);
706         } else {
707             return Ok(0);
708         }
709     }
710     Ok(the_field)
711 }
712 "#,
713         );
714
715         check_assist(
716             wrap_return_type_in_result,
717             r#"
718             fn foo(the_field: u32) -> u32$0 {
719                 let true_closure = || {
720                     return true;
721                 };
722                 if the_field < 5 {
723                     let mut i = 0;
724
725
726                     if true_closure() {
727                         return 99;
728                     } else {
729                         return 0;
730                     }
731                 }
732                 let t = None;
733
734                 t.unwrap_or_else(|| the_field)
735             }
736             "#,
737             r#"
738             fn foo(the_field: u32) -> Result<u32, ${0:_}> {
739                 let true_closure = || {
740                     return true;
741                 };
742                 if the_field < 5 {
743                     let mut i = 0;
744
745
746                     if true_closure() {
747                         return Ok(99);
748                     } else {
749                         return Ok(0);
750                     }
751                 }
752                 let t = None;
753
754                 Ok(t.unwrap_or_else(|| the_field))
755             }
756             "#,
757         );
758     }
759
760     #[test]
761     fn wrap_return_type_in_result_simple_with_weird_forms() {
762         check_assist(
763             wrap_return_type_in_result,
764             r#"
765 fn foo() -> i32$0 {
766     let test = "test";
767     if test == "test" {
768         return 24i32;
769     }
770     let mut i = 0;
771     loop {
772         if i == 1 {
773             break 55;
774         }
775         i += 1;
776     }
777 }
778 "#,
779             r#"
780 fn foo() -> Result<i32, ${0:_}> {
781     let test = "test";
782     if test == "test" {
783         return Ok(24i32);
784     }
785     let mut i = 0;
786     loop {
787         if i == 1 {
788             break Ok(55);
789         }
790         i += 1;
791     }
792 }
793 "#,
794         );
795
796         check_assist(
797             wrap_return_type_in_result,
798             r#"
799 fn foo(the_field: u32) -> u32$0 {
800     if the_field < 5 {
801         let mut i = 0;
802         loop {
803             if i > 5 {
804                 return 55u32;
805             }
806             i += 3;
807         }
808         match i {
809             5 => return 99,
810             _ => return 0,
811         };
812     }
813     the_field
814 }
815 "#,
816             r#"
817 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
818     if the_field < 5 {
819         let mut i = 0;
820         loop {
821             if i > 5 {
822                 return Ok(55u32);
823             }
824             i += 3;
825         }
826         match i {
827             5 => return Ok(99),
828             _ => return Ok(0),
829         };
830     }
831     Ok(the_field)
832 }
833 "#,
834         );
835
836         check_assist(
837             wrap_return_type_in_result,
838             r#"
839 fn foo(the_field: u32) -> u3$02 {
840     if the_field < 5 {
841         let mut i = 0;
842         match i {
843             5 => return 99,
844             _ => return 0,
845         }
846     }
847     the_field
848 }
849 "#,
850             r#"
851 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
852     if the_field < 5 {
853         let mut i = 0;
854         match i {
855             5 => return Ok(99),
856             _ => return Ok(0),
857         }
858     }
859     Ok(the_field)
860 }
861 "#,
862         );
863
864         check_assist(
865             wrap_return_type_in_result,
866             r#"
867 fn foo(the_field: u32) -> u32$0 {
868     if the_field < 5 {
869         let mut i = 0;
870         if i == 5 {
871             return 99
872         } else {
873             return 0
874         }
875     }
876     the_field
877 }
878 "#,
879             r#"
880 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
881     if the_field < 5 {
882         let mut i = 0;
883         if i == 5 {
884             return Ok(99)
885         } else {
886             return Ok(0)
887         }
888     }
889     Ok(the_field)
890 }
891 "#,
892         );
893
894         check_assist(
895             wrap_return_type_in_result,
896             r#"
897 fn foo(the_field: u32) -> $0u32 {
898     if the_field < 5 {
899         let mut i = 0;
900         if i == 5 {
901             return 99;
902         } else {
903             return 0;
904         }
905     }
906     the_field
907 }
908 "#,
909             r#"
910 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
911     if the_field < 5 {
912         let mut i = 0;
913         if i == 5 {
914             return Ok(99);
915         } else {
916             return Ok(0);
917         }
918     }
919     Ok(the_field)
920 }
921 "#,
922         );
923     }
924 }