2 famous_defs::FamousDefs,
3 syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
5 use itertools::Itertools;
8 match_ast, AstNode, TextRange, TextSize,
11 use crate::{AssistContext, AssistId, AssistKind, Assists};
13 // Assist: unwrap_result_return_type
15 // Unwrap the `Result` return type of a function or closure, replacing it with the `Ok` type.
18 // # //- minicore: result
19 // fn foo() -> Result<i32>$0 { Ok(42i32) }
23 // fn foo() -> i32 { 42i32 }
25 pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
26 let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
27 let parent = ret_type.syntax().parent()?;
28 let body = match_ast! {
30 ast::Fn(func) => func.body()?,
31 ast::ClosureExpr(closure) => match closure.body()? {
32 Expr::BlockExpr(block) => block,
33 // closures require a block when a return type is specified
40 let type_ref = &ret_type.ty()?;
41 let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
43 FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?;
45 if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
50 AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite),
51 "Unwrap Result return type",
52 type_ref.syntax().text_range(),
54 let body = ast::Expr::BlockExpr(body);
56 let mut exprs_to_unwrap = Vec::new();
57 let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_unwrap, e);
58 walk_expr(&body, &mut |expr| {
59 if let Expr::ReturnExpr(ret_expr) = expr {
60 if let Some(ret_expr_arg) = &ret_expr.expr() {
61 for_each_tail_expr(ret_expr_arg, tail_cb);
65 for_each_tail_expr(&body, tail_cb);
67 let mut is_unit_type = false;
68 if let Some((_, inner_type)) = type_ref.to_string().split_once('<') {
69 let inner_type = match inner_type.split_once(',') {
70 Some((success_inner_type, _)) => success_inner_type,
73 let new_ret_type = inner_type.strip_suffix('>').unwrap_or(inner_type);
74 if new_ret_type == "()" {
76 let text_range = TextRange::new(
77 ret_type.syntax().text_range().start(),
78 ret_type.syntax().text_range().end() + TextSize::from(1u32),
80 builder.delete(text_range)
83 type_ref.syntax().text_range(),
84 inner_type.strip_suffix('>').unwrap_or(inner_type),
89 for ret_expr_arg in exprs_to_unwrap {
90 let ret_expr_str = ret_expr_arg.to_string();
91 if ret_expr_str.starts_with("Ok(") || ret_expr_str.starts_with("Err(") {
92 let arg_list = ret_expr_arg.syntax().children().find_map(ast::ArgList::cast);
93 if let Some(arg_list) = arg_list {
95 match ret_expr_arg.syntax().prev_sibling_or_token() {
96 // Useful to delete the entire line without leaving trailing whitespaces
98 let new_range = TextRange::new(
99 whitespace.text_range().start(),
100 ret_expr_arg.syntax().text_range().end(),
102 builder.delete(new_range);
105 builder.delete(ret_expr_arg.syntax().text_range());
110 ret_expr_arg.syntax().text_range(),
111 arg_list.args().join(", "),
121 fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
123 Expr::BreakExpr(break_expr) => {
124 if let Some(break_expr_arg) = break_expr.expr() {
125 for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
128 Expr::ReturnExpr(ret_expr) => {
129 if let Some(ret_expr_arg) = &ret_expr.expr() {
130 for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e));
133 e => acc.push(e.clone()),
139 use crate::tests::{check_assist, check_assist_not_applicable};
144 fn unwrap_result_return_type_simple() {
146 unwrap_result_return_type,
149 fn foo() -> Result<i3$02> {
164 fn unwrap_result_return_type_unit_type() {
166 unwrap_result_return_type,
169 fn foo() -> Result<(), Box<dyn Error$0>> {
181 fn unwrap_result_return_type_ending_with_parent() {
183 unwrap_result_return_type,
186 fn foo() -> Result<i32, Box<dyn Error$0>> {
207 fn unwrap_return_type_break_split_tail() {
209 unwrap_result_return_type,
212 fn foo() -> Result<i3$02, String> {
237 fn unwrap_result_return_type_simple_closure() {
239 unwrap_result_return_type,
243 || -> Result<i32$0> {
261 fn unwrap_result_return_type_simple_return_type_bad_cursor() {
262 check_assist_not_applicable(
263 unwrap_result_return_type,
275 fn unwrap_result_return_type_simple_return_type_bad_cursor_closure() {
276 check_assist_not_applicable(
277 unwrap_result_return_type,
291 fn unwrap_result_return_type_closure_non_block() {
292 check_assist_not_applicable(
293 unwrap_result_return_type,
296 fn foo() { || -> i$032 3; }
302 fn unwrap_result_return_type_simple_return_type_already_not_result_std() {
303 check_assist_not_applicable(
304 unwrap_result_return_type,
316 fn unwrap_result_return_type_simple_return_type_already_not_result_closure() {
317 check_assist_not_applicable(
318 unwrap_result_return_type,
332 fn unwrap_result_return_type_simple_with_tail() {
334 unwrap_result_return_type,
337 fn foo() ->$0 Result<i32> {
352 fn unwrap_result_return_type_simple_with_tail_closure() {
354 unwrap_result_return_type,
358 || ->$0 Result<i32, String> {
376 fn unwrap_result_return_type_simple_with_tail_only() {
378 unwrap_result_return_type,
381 fn foo() -> Result<i32$0> { Ok(42i32) }
384 fn foo() -> i32 { 42i32 }
390 fn unwrap_result_return_type_simple_with_tail_block_like() {
392 unwrap_result_return_type,
395 fn foo() -> Result<i32>$0 {
416 fn unwrap_result_return_type_simple_without_block_closure() {
418 unwrap_result_return_type,
422 || -> Result<i32, String>$0 {
446 fn unwrap_result_return_type_simple_with_nested_if() {
448 unwrap_result_return_type,
451 fn foo() -> Result<i32>$0 {
480 fn unwrap_result_return_type_simple_with_await() {
482 unwrap_result_return_type,
485 async fn foo() -> Result<i$032> {
498 async fn foo() -> i32 {
514 fn unwrap_result_return_type_simple_with_array() {
516 unwrap_result_return_type,
519 fn foo() -> Result<[i32; 3]$0> { Ok([1, 2, 3]) }
522 fn foo() -> [i32; 3] { [1, 2, 3] }
528 fn unwrap_result_return_type_simple_with_cast() {
530 unwrap_result_return_type,
533 fn foo() -$0> Result<i32> {
562 fn unwrap_result_return_type_simple_with_tail_block_like_match() {
564 unwrap_result_return_type,
567 fn foo() -> Result<i32$0> {
588 fn unwrap_result_return_type_simple_with_loop_with_tail() {
590 unwrap_result_return_type,
593 fn foo() -> Result<i32$0> {
616 fn unwrap_result_return_type_simple_with_loop_in_let_stmt() {
618 unwrap_result_return_type,
621 fn foo() -> Result<i32$0> {
622 let my_var = let x = loop {
630 let my_var = let x = loop {
640 fn unwrap_result_return_type_simple_with_tail_block_like_match_return_expr() {
642 unwrap_result_return_type,
645 fn foo() -> Result<i32>$0 {
647 let res = match my_var {
649 _ => return Ok(24i32),
657 let res = match my_var {
667 unwrap_result_return_type,
670 fn foo() -> Result<i32$0> {
672 let res = if my_var == 5 {
683 let res = if my_var == 5 {
695 fn unwrap_result_return_type_simple_with_tail_block_like_match_deeper() {
697 unwrap_result_return_type,
700 fn foo() -> Result<i32$0> {
745 fn unwrap_result_return_type_simple_with_tail_block_like_early_return() {
747 unwrap_result_return_type,
750 fn foo() -> Result<i32$0> {
771 fn unwrap_result_return_type_simple_with_closure() {
773 unwrap_result_return_type,
776 fn foo(the_field: u32) -> Result<u32$0> {
777 let true_closure = || { return true; };
790 fn foo(the_field: u32) -> u32 {
791 let true_closure = || { return true; };
806 unwrap_result_return_type,
809 fn foo(the_field: u32) -> Result<u32$0> {
810 let true_closure = || {
825 Ok(t.unwrap_or_else(|| the_field))
829 fn foo(the_field: u32) -> u32 {
830 let true_closure = || {
845 t.unwrap_or_else(|| the_field)
852 fn unwrap_result_return_type_simple_with_weird_forms() {
854 unwrap_result_return_type,
857 fn foo() -> Result<i32$0> {
889 unwrap_result_return_type,
892 fn foo(the_field: u32) -> Result<u32$0> {
910 fn foo(the_field: u32) -> u32 {
930 unwrap_result_return_type,
933 fn foo(the_field: u32) -> Result<u32$0> {
945 fn foo(the_field: u32) -> u32 {
959 unwrap_result_return_type,
962 fn foo(the_field: u32) -> Result<u32$0> {
975 fn foo(the_field: u32) -> u32 {
990 unwrap_result_return_type,
993 fn foo(the_field: u32) -> Result<u3$02> {
1006 fn foo(the_field: u32) -> u32 {