4 ast::{self, make, BlockExpr, Expr, LoopBodyOwner},
5 match_ast, AstNode, SyntaxNode,
8 use crate::{AssistContext, AssistId, AssistKind, Assists};
10 // Assist: wrap_return_type_in_result
12 // Wrap the function's return type into Result.
15 // fn foo() -> i32$0 { 42i32 }
19 // fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
21 pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
22 let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
23 let parent = ret_type.syntax().parent()?;
24 let block_expr = match_ast! {
26 ast::Fn(func) => func.body()?,
27 ast::ClosureExpr(closure) => match closure.body()? {
28 Expr::BlockExpr(block) => block,
29 // closures require a block when a return type is specified
36 let type_ref = &ret_type.ty()?;
37 let ret_type_str = type_ref.syntax().text().to_string();
38 let first_part_ret_type = ret_type_str.splitn(2, '<').next();
39 if let Some(ret_type_first_part) = first_part_ret_type {
40 if ret_type_first_part.ends_with("Result") {
41 cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
47 AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite),
48 "Wrap return type in Result",
49 type_ref.syntax().text_range(),
51 let mut tail_return_expr_collector = TailReturnCollector::new();
52 tail_return_expr_collector.collect_jump_exprs(&block_expr, false);
53 tail_return_expr_collector.collect_tail_exprs(&block_expr);
55 for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap {
56 let ok_wrapped = make::expr_call(
57 make::expr_path(make::ext::ident_path("Ok")),
58 make::arg_list(iter::once(ret_expr_arg.clone())),
60 builder.replace_ast(ret_expr_arg, ok_wrapped);
63 match ctx.config.snippet_cap {
65 let snippet = format!("Result<{}, ${{0:_}}>", type_ref);
66 builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet)
69 .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)),
75 struct TailReturnCollector {
76 exprs_to_wrap: Vec<ast::Expr>,
79 impl TailReturnCollector {
81 Self { exprs_to_wrap: vec![] }
83 /// Collect all`return` expression
84 fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) {
85 let statements = block_expr.statements();
86 for stmt in statements {
87 let expr = match &stmt {
88 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
89 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
90 ast::Stmt::Item(_) => continue,
92 if let Some(expr) = &expr {
93 self.handle_exprs(expr, collect_break);
97 // Browse tail expressions for each block
98 if let Some(expr) = block_expr.tail_expr() {
99 if let Some(last_exprs) = get_tail_expr_from_block(&expr) {
100 for last_expr in last_exprs {
101 let last_expr = match last_expr {
102 NodeType::Node(expr) => expr,
103 NodeType::Leaf(expr) => expr.syntax().clone(),
106 if let Some(last_expr) = Expr::cast(last_expr.clone()) {
107 self.handle_exprs(&last_expr, collect_break);
108 } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) {
109 let expr_stmt = match &expr_stmt {
110 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
111 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
112 ast::Stmt::Item(_) => None,
114 if let Some(expr) = &expr_stmt {
115 self.handle_exprs(expr, collect_break);
123 fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) {
125 Expr::BlockExpr(block_expr) => {
126 self.collect_jump_exprs(block_expr, collect_break);
128 Expr::ReturnExpr(ret_expr) => {
129 if let Some(ret_expr_arg) = &ret_expr.expr() {
130 self.exprs_to_wrap.push(ret_expr_arg.clone());
133 Expr::BreakExpr(break_expr) if collect_break => {
134 if let Some(break_expr_arg) = &break_expr.expr() {
135 self.exprs_to_wrap.push(break_expr_arg.clone());
138 Expr::IfExpr(if_expr) => {
139 for block in if_expr.blocks() {
140 self.collect_jump_exprs(&block, collect_break);
143 Expr::LoopExpr(loop_expr) => {
144 if let Some(block_expr) = loop_expr.loop_body() {
145 self.collect_jump_exprs(&block_expr, collect_break);
148 Expr::ForExpr(for_expr) => {
149 if let Some(block_expr) = for_expr.loop_body() {
150 self.collect_jump_exprs(&block_expr, collect_break);
153 Expr::WhileExpr(while_expr) => {
154 if let Some(block_expr) = while_expr.loop_body() {
155 self.collect_jump_exprs(&block_expr, collect_break);
158 Expr::MatchExpr(match_expr) => {
159 if let Some(arm_list) = match_expr.match_arm_list() {
160 arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| {
161 self.handle_exprs(&expr, collect_break);
169 fn collect_tail_exprs(&mut self, block: &BlockExpr) {
170 if let Some(expr) = block.tail_expr() {
171 self.handle_exprs(&expr, true);
172 self.fetch_tail_exprs(&expr);
176 fn fetch_tail_exprs(&mut self, expr: &Expr) {
177 if let Some(exprs) = get_tail_expr_from_block(expr) {
178 for node_type in &exprs {
180 NodeType::Leaf(expr) => {
181 self.exprs_to_wrap.push(expr.clone());
183 NodeType::Node(expr) => {
184 if let Some(last_expr) = Expr::cast(expr.clone()) {
185 self.fetch_tail_exprs(&last_expr);
200 /// Get a tail expression inside a block
201 fn get_tail_expr_from_block(expr: &Expr) -> Option<Vec<NodeType>> {
203 Expr::IfExpr(if_expr) => {
204 let mut nodes = vec![];
205 for block in if_expr.blocks() {
206 if let Some(block_expr) = block.tail_expr() {
207 if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) {
208 nodes.extend(tail_exprs);
210 } else if let Some(last_expr) = block.syntax().last_child() {
211 nodes.push(NodeType::Node(last_expr));
213 nodes.push(NodeType::Node(block.syntax().clone()));
218 Expr::LoopExpr(loop_expr) => {
219 loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
221 Expr::ForExpr(for_expr) => {
222 for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
224 Expr::WhileExpr(while_expr) => {
225 while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
227 Expr::BlockExpr(block_expr) => {
228 block_expr.tail_expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())])
230 Expr::MatchExpr(match_expr) => {
231 let arm_list = match_expr.match_arm_list()?;
232 let arms: Vec<NodeType> = arm_list
234 .filter_map(|match_arm| match_arm.expr())
235 .map(|expr| match expr {
236 Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()),
237 Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()),
238 _ => match expr.syntax().last_child() {
239 Some(last_expr) => NodeType::Node(last_expr),
240 None => NodeType::Node(expr.syntax().clone()),
247 Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e)]),
248 Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]),
256 | Expr::RecordExpr(_)
258 | Expr::MethodCallExpr(_)
262 | Expr::PrefixExpr(_)
266 | Expr::BoxExpr(_) => Some(vec![NodeType::Leaf(expr.clone())]),
273 use crate::tests::{check_assist, check_assist_not_applicable};
278 fn wrap_return_type_in_result_simple() {
280 wrap_return_type_in_result,
288 fn foo() -> Result<i32, ${0:_}> {
297 fn wrap_return_type_in_result_simple_closure() {
299 wrap_return_type_in_result,
310 || -> Result<i32, ${0:_}> {
320 fn wrap_return_type_in_result_simple_return_type_bad_cursor() {
321 check_assist_not_applicable(
322 wrap_return_type_in_result,
333 fn wrap_return_type_in_result_simple_return_type_bad_cursor_closure() {
334 check_assist_not_applicable(
335 wrap_return_type_in_result,
348 fn wrap_return_type_in_result_closure_non_block() {
349 check_assist_not_applicable(wrap_return_type_in_result, r#"fn foo() { || -> i$032 3; }"#);
353 fn wrap_return_type_in_result_simple_return_type_already_result_std() {
354 check_assist_not_applicable(
355 wrap_return_type_in_result,
357 fn foo() -> std::result::Result<i32$0, String> {
366 fn wrap_return_type_in_result_simple_return_type_already_result() {
367 cov_mark::check!(wrap_return_type_in_result_simple_return_type_already_result);
368 check_assist_not_applicable(
369 wrap_return_type_in_result,
371 fn foo() -> Result<i32$0, String> {
380 fn wrap_return_type_in_result_simple_return_type_already_result_closure() {
381 check_assist_not_applicable(
382 wrap_return_type_in_result,
385 || -> Result<i32$0, String> {
395 fn wrap_return_type_in_result_simple_with_cursor() {
397 wrap_return_type_in_result,
405 fn foo() -> Result<i32, ${0:_}> {
414 fn wrap_return_type_in_result_simple_with_tail() {
416 wrap_return_type_in_result,
424 fn foo() -> Result<i32, ${0:_}> {
433 fn wrap_return_type_in_result_simple_with_tail_closure() {
435 wrap_return_type_in_result,
446 || -> Result<i32, ${0:_}> {
456 fn wrap_return_type_in_result_simple_with_tail_only() {
458 wrap_return_type_in_result,
459 r#"fn foo() -> i32$0 { 42i32 }"#,
460 r#"fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }"#,
465 fn wrap_return_type_in_result_simple_with_tail_block_like() {
467 wrap_return_type_in_result,
478 fn foo() -> Result<i32, ${0:_}> {
490 fn wrap_return_type_in_result_simple_without_block_closure() {
492 wrap_return_type_in_result,
506 || -> Result<i32, ${0:_}> {
519 fn wrap_return_type_in_result_simple_with_nested_if() {
521 wrap_return_type_in_result,
536 fn foo() -> Result<i32, ${0:_}> {
552 fn wrap_return_type_in_result_simple_with_await() {
554 wrap_return_type_in_result,
556 async fn foo() -> i$032 {
569 async fn foo() -> Result<i32, ${0:_}> {
585 fn wrap_return_type_in_result_simple_with_array() {
587 wrap_return_type_in_result,
588 r#"fn foo() -> [i32;$0 3] { [1, 2, 3] }"#,
589 r#"fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }"#,
594 fn wrap_return_type_in_result_simple_with_cast() {
596 wrap_return_type_in_result,
611 fn foo() -> Result<i32, ${0:_}> {
627 fn wrap_return_type_in_result_simple_with_tail_block_like_match() {
629 wrap_return_type_in_result,
640 fn foo() -> Result<i32, ${0:_}> {
652 fn wrap_return_type_in_result_simple_with_loop_with_tail() {
654 wrap_return_type_in_result,
666 fn foo() -> Result<i32, ${0:_}> {
679 fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() {
681 wrap_return_type_in_result,
684 let my_var = let x = loop {
691 fn foo() -> Result<i32, ${0:_}> {
692 let my_var = let x = loop {
702 fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() {
704 wrap_return_type_in_result,
708 let res = match my_var {
716 fn foo() -> Result<i32, ${0:_}> {
718 let res = match my_var {
720 _ => return Ok(24i32),
728 wrap_return_type_in_result,
732 let res = if my_var == 5 {
741 fn foo() -> Result<i32, ${0:_}> {
743 let res = if my_var == 5 {
755 fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() {
757 wrap_return_type_in_result,
780 fn foo() -> Result<i32, ${0:_}> {
804 fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() {
806 wrap_return_type_in_result,
817 fn foo() -> Result<i32, ${0:_}> {
829 fn wrap_return_type_in_result_simple_with_closure() {
831 wrap_return_type_in_result,
833 fn foo(the_field: u32) ->$0 u32 {
834 let true_closure = || { return true; };
847 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
848 let true_closure = || { return true; };
863 wrap_return_type_in_result,
865 fn foo(the_field: u32) -> u32$0 {
866 let true_closure = || {
881 t.unwrap_or_else(|| the_field)
885 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
886 let true_closure = || {
901 Ok(t.unwrap_or_else(|| the_field))
908 fn wrap_return_type_in_result_simple_with_weird_forms() {
910 wrap_return_type_in_result,
927 fn foo() -> Result<i32, ${0:_}> {
944 wrap_return_type_in_result,
963 fn foo() -> Result<i32, ${0:_}> {
982 wrap_return_type_in_result,
988 let res = match other {
1005 fn foo() -> Result<i32, ${0:_}> {
1009 let res = match other {
1028 wrap_return_type_in_result,
1030 fn foo(the_field: u32) -> u32$0 {
1048 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1068 wrap_return_type_in_result,
1070 fn foo(the_field: u32) -> u3$02 {
1082 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1096 wrap_return_type_in_result,
1098 fn foo(the_field: u32) -> u32$0 {
1111 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1126 wrap_return_type_in_result,
1128 fn foo(the_field: u32) -> $0u32 {
1141 fn foo(the_field: u32) -> Result<u32, ${0:_}> {