4 ast::{self, make, BlockExpr, Expr, LoopBodyOwner},
5 match_ast, AstNode, SyntaxNode,
8 use crate::{AssistContext, AssistId, AssistKind, Assists};
10 // Assist: wrap_return_type_in_result
12 // Wrap the function's return type into Result.
15 // fn foo() -> i32$0 { 42i32 }
19 // fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
21 pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
22 let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
23 let parent = ret_type.syntax().parent()?;
24 let block_expr = match_ast! {
26 ast::Fn(func) => func.body()?,
27 ast::ClosureExpr(closure) => match closure.body()? {
28 Expr::BlockExpr(block) => block,
29 // closures require a block when a return type is specified
36 let type_ref = &ret_type.ty()?;
37 let ret_type_str = type_ref.syntax().text().to_string();
38 let first_part_ret_type = ret_type_str.splitn(2, '<').next();
39 if let Some(ret_type_first_part) = first_part_ret_type {
40 if ret_type_first_part.ends_with("Result") {
41 cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
47 AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite),
48 "Wrap return type in Result",
49 type_ref.syntax().text_range(),
51 let mut tail_return_expr_collector = TailReturnCollector::new();
52 tail_return_expr_collector.collect_jump_exprs(&block_expr, false);
53 tail_return_expr_collector.collect_tail_exprs(&block_expr);
55 for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap {
56 let ok_wrapped = make::expr_call(
57 make::expr_path(make::path_unqualified(make::path_segment(make::name_ref(
60 make::arg_list(iter::once(ret_expr_arg.clone())),
62 builder.replace_ast(ret_expr_arg, ok_wrapped);
65 match ctx.config.snippet_cap {
67 let snippet = format!("Result<{}, ${{0:_}}>", type_ref);
68 builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet)
71 .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)),
77 struct TailReturnCollector {
78 exprs_to_wrap: Vec<ast::Expr>,
81 impl TailReturnCollector {
83 Self { exprs_to_wrap: vec![] }
85 /// Collect all`return` expression
86 fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) {
87 let statements = block_expr.statements();
88 for stmt in statements {
89 let expr = match &stmt {
90 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
91 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
92 ast::Stmt::Item(_) => continue,
94 if let Some(expr) = &expr {
95 self.handle_exprs(expr, collect_break);
99 // Browse tail expressions for each block
100 if let Some(expr) = block_expr.tail_expr() {
101 if let Some(last_exprs) = get_tail_expr_from_block(&expr) {
102 for last_expr in last_exprs {
103 let last_expr = match last_expr {
104 NodeType::Node(expr) => expr,
105 NodeType::Leaf(expr) => expr.syntax().clone(),
108 if let Some(last_expr) = Expr::cast(last_expr.clone()) {
109 self.handle_exprs(&last_expr, collect_break);
110 } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) {
111 let expr_stmt = match &expr_stmt {
112 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
113 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
114 ast::Stmt::Item(_) => None,
116 if let Some(expr) = &expr_stmt {
117 self.handle_exprs(expr, collect_break);
125 fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) {
127 Expr::BlockExpr(block_expr) => {
128 self.collect_jump_exprs(&block_expr, collect_break);
130 Expr::ReturnExpr(ret_expr) => {
131 if let Some(ret_expr_arg) = &ret_expr.expr() {
132 self.exprs_to_wrap.push(ret_expr_arg.clone());
135 Expr::BreakExpr(break_expr) if collect_break => {
136 if let Some(break_expr_arg) = &break_expr.expr() {
137 self.exprs_to_wrap.push(break_expr_arg.clone());
140 Expr::IfExpr(if_expr) => {
141 for block in if_expr.blocks() {
142 self.collect_jump_exprs(&block, collect_break);
145 Expr::LoopExpr(loop_expr) => {
146 if let Some(block_expr) = loop_expr.loop_body() {
147 self.collect_jump_exprs(&block_expr, collect_break);
150 Expr::ForExpr(for_expr) => {
151 if let Some(block_expr) = for_expr.loop_body() {
152 self.collect_jump_exprs(&block_expr, collect_break);
155 Expr::WhileExpr(while_expr) => {
156 if let Some(block_expr) = while_expr.loop_body() {
157 self.collect_jump_exprs(&block_expr, collect_break);
160 Expr::MatchExpr(match_expr) => {
161 if let Some(arm_list) = match_expr.match_arm_list() {
162 arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| {
163 self.handle_exprs(&expr, collect_break);
171 fn collect_tail_exprs(&mut self, block: &BlockExpr) {
172 if let Some(expr) = block.tail_expr() {
173 self.handle_exprs(&expr, true);
174 self.fetch_tail_exprs(&expr);
178 fn fetch_tail_exprs(&mut self, expr: &Expr) {
179 if let Some(exprs) = get_tail_expr_from_block(expr) {
180 for node_type in &exprs {
182 NodeType::Leaf(expr) => {
183 self.exprs_to_wrap.push(expr.clone());
185 NodeType::Node(expr) => {
186 if let Some(last_expr) = Expr::cast(expr.clone()) {
187 self.fetch_tail_exprs(&last_expr);
202 /// Get a tail expression inside a block
203 fn get_tail_expr_from_block(expr: &Expr) -> Option<Vec<NodeType>> {
205 Expr::IfExpr(if_expr) => {
206 let mut nodes = vec![];
207 for block in if_expr.blocks() {
208 if let Some(block_expr) = block.tail_expr() {
209 if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) {
210 nodes.extend(tail_exprs);
212 } else if let Some(last_expr) = block.syntax().last_child() {
213 nodes.push(NodeType::Node(last_expr));
215 nodes.push(NodeType::Node(block.syntax().clone()));
220 Expr::LoopExpr(loop_expr) => {
221 loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
223 Expr::ForExpr(for_expr) => {
224 for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
226 Expr::WhileExpr(while_expr) => {
227 while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
229 Expr::BlockExpr(block_expr) => {
230 block_expr.tail_expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())])
232 Expr::MatchExpr(match_expr) => {
233 let arm_list = match_expr.match_arm_list()?;
234 let arms: Vec<NodeType> = arm_list
236 .filter_map(|match_arm| match_arm.expr())
237 .map(|expr| match expr {
238 Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()),
239 Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()),
240 _ => match expr.syntax().last_child() {
241 Some(last_expr) => NodeType::Node(last_expr),
242 None => NodeType::Node(expr.syntax().clone()),
249 Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e)]),
250 Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]),
258 | Expr::RecordExpr(_)
260 | Expr::MethodCallExpr(_)
264 | Expr::PrefixExpr(_)
268 | Expr::BoxExpr(_) => Some(vec![NodeType::Leaf(expr.clone())]),
275 use crate::tests::{check_assist, check_assist_not_applicable};
280 fn wrap_return_type_in_result_simple() {
282 wrap_return_type_in_result,
290 fn foo() -> Result<i32, ${0:_}> {
299 fn wrap_return_type_in_result_simple_closure() {
301 wrap_return_type_in_result,
312 || -> Result<i32, ${0:_}> {
322 fn wrap_return_type_in_result_simple_return_type_bad_cursor() {
323 check_assist_not_applicable(
324 wrap_return_type_in_result,
335 fn wrap_return_type_in_result_simple_return_type_bad_cursor_closure() {
336 check_assist_not_applicable(
337 wrap_return_type_in_result,
350 fn wrap_return_type_in_result_closure_non_block() {
351 check_assist_not_applicable(wrap_return_type_in_result, r#"fn foo() { || -> i$032 3; }"#);
355 fn wrap_return_type_in_result_simple_return_type_already_result_std() {
356 check_assist_not_applicable(
357 wrap_return_type_in_result,
359 fn foo() -> std::result::Result<i32$0, String> {
368 fn wrap_return_type_in_result_simple_return_type_already_result() {
369 cov_mark::check!(wrap_return_type_in_result_simple_return_type_already_result);
370 check_assist_not_applicable(
371 wrap_return_type_in_result,
373 fn foo() -> Result<i32$0, String> {
382 fn wrap_return_type_in_result_simple_return_type_already_result_closure() {
383 check_assist_not_applicable(
384 wrap_return_type_in_result,
387 || -> Result<i32$0, String> {
397 fn wrap_return_type_in_result_simple_with_cursor() {
399 wrap_return_type_in_result,
407 fn foo() -> Result<i32, ${0:_}> {
416 fn wrap_return_type_in_result_simple_with_tail() {
418 wrap_return_type_in_result,
426 fn foo() -> Result<i32, ${0:_}> {
435 fn wrap_return_type_in_result_simple_with_tail_closure() {
437 wrap_return_type_in_result,
448 || -> Result<i32, ${0:_}> {
458 fn wrap_return_type_in_result_simple_with_tail_only() {
460 wrap_return_type_in_result,
461 r#"fn foo() -> i32$0 { 42i32 }"#,
462 r#"fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }"#,
467 fn wrap_return_type_in_result_simple_with_tail_block_like() {
469 wrap_return_type_in_result,
480 fn foo() -> Result<i32, ${0:_}> {
492 fn wrap_return_type_in_result_simple_without_block_closure() {
494 wrap_return_type_in_result,
508 || -> Result<i32, ${0:_}> {
521 fn wrap_return_type_in_result_simple_with_nested_if() {
523 wrap_return_type_in_result,
538 fn foo() -> Result<i32, ${0:_}> {
554 fn wrap_return_type_in_result_simple_with_await() {
556 wrap_return_type_in_result,
558 async fn foo() -> i$032 {
571 async fn foo() -> Result<i32, ${0:_}> {
587 fn wrap_return_type_in_result_simple_with_array() {
589 wrap_return_type_in_result,
590 r#"fn foo() -> [i32;$0 3] { [1, 2, 3] }"#,
591 r#"fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }"#,
596 fn wrap_return_type_in_result_simple_with_cast() {
598 wrap_return_type_in_result,
613 fn foo() -> Result<i32, ${0:_}> {
629 fn wrap_return_type_in_result_simple_with_tail_block_like_match() {
631 wrap_return_type_in_result,
642 fn foo() -> Result<i32, ${0:_}> {
654 fn wrap_return_type_in_result_simple_with_loop_with_tail() {
656 wrap_return_type_in_result,
668 fn foo() -> Result<i32, ${0:_}> {
681 fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() {
683 wrap_return_type_in_result,
686 let my_var = let x = loop {
693 fn foo() -> Result<i32, ${0:_}> {
694 let my_var = let x = loop {
704 fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() {
706 wrap_return_type_in_result,
710 let res = match my_var {
718 fn foo() -> Result<i32, ${0:_}> {
720 let res = match my_var {
722 _ => return Ok(24i32),
730 wrap_return_type_in_result,
734 let res = if my_var == 5 {
743 fn foo() -> Result<i32, ${0:_}> {
745 let res = if my_var == 5 {
757 fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() {
759 wrap_return_type_in_result,
782 fn foo() -> Result<i32, ${0:_}> {
806 fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() {
808 wrap_return_type_in_result,
819 fn foo() -> Result<i32, ${0:_}> {
831 fn wrap_return_type_in_result_simple_with_closure() {
833 wrap_return_type_in_result,
835 fn foo(the_field: u32) ->$0 u32 {
836 let true_closure = || { return true; };
849 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
850 let true_closure = || { return true; };
865 wrap_return_type_in_result,
867 fn foo(the_field: u32) -> u32$0 {
868 let true_closure = || {
883 t.unwrap_or_else(|| the_field)
887 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
888 let true_closure = || {
903 Ok(t.unwrap_or_else(|| the_field))
910 fn wrap_return_type_in_result_simple_with_weird_forms() {
912 wrap_return_type_in_result,
929 fn foo() -> Result<i32, ${0:_}> {
946 wrap_return_type_in_result,
965 fn foo() -> Result<i32, ${0:_}> {
984 wrap_return_type_in_result,
990 let res = match other {
1007 fn foo() -> Result<i32, ${0:_}> {
1011 let res = match other {
1030 wrap_return_type_in_result,
1032 fn foo(the_field: u32) -> u32$0 {
1050 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1070 wrap_return_type_in_result,
1072 fn foo(the_field: u32) -> u3$02 {
1084 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1098 wrap_return_type_in_result,
1100 fn foo(the_field: u32) -> u32$0 {
1113 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1128 wrap_return_type_in_result,
1130 fn foo(the_field: u32) -> $0u32 {
1143 fn foo(the_field: u32) -> Result<u32, ${0:_}> {