4 ast::{self, make, BlockExpr, Expr, LoopBodyOwner},
5 match_ast, AstNode, SyntaxNode,
9 use crate::{AssistContext, AssistId, AssistKind, Assists};
11 // Assist: wrap_return_type_in_result
13 // Wrap the function's return type into Result.
16 // fn foo() -> i32$0 { 42i32 }
20 // fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
22 pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
// Offers the "Wrap return type in Result" assist for the return type under the cursor.
// Every `?` below yields `None` when the corresponding lookup fails.
23 let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
24 let parent = ret_type.syntax().parent()?;
// The body to rewrite comes from the owner of the return type: a function or a closure.
25 let block_expr = match_ast! {
27 ast::Fn(func) => func.body()?,
28 ast::ClosureExpr(closure) => match closure.body()? {
29 Expr::BlockExpr(block) => block,
30 // closures require a block when a return type is specified
37 let type_ref = &ret_type.ty()?;
// Purely textual check for an existing `Result`: only the text before the first `<` is
// inspected. NOTE(review): `ends_with("Result")` also matches names such as `MyResult`
// or `io::Result`; the assist is presumably declined in that case (see the
// `already_result` tests) -- confirm.
38 let ret_type_str = type_ref.syntax().text().to_string();
39 let first_part_ret_type = ret_type_str.splitn(2, '<').next();
40 if let Some(ret_type_first_part) = first_part_ret_type {
41 if ret_type_first_part.ends_with("Result") {
42 mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
48 AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite),
49 "Wrap return type in Result",
50 type_ref.syntax().text_range(),
52 let mut tail_return_expr_collector = TailReturnCollector::new();
// First pass gathers `return` operands only (`collect_break` is `false`); the second
// pass handles the tail expression of the body.
53 tail_return_expr_collector.collect_jump_exprs(&block_expr, false);
54 tail_return_expr_collector.collect_tail_exprs(&block_expr);
// Each collected expression is replaced by a call that takes it as its only argument
// (the `Ok(..)` shown in the assist example).
56 for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap {
57 let ok_wrapped = make::expr_call(
58 make::expr_path(make::path_unqualified(make::path_segment(make::name_ref(
61 make::arg_list(iter::once(ret_expr_arg.clone())),
63 builder.replace_ast(ret_expr_arg, ok_wrapped);
// Rewrite the return type itself: with snippet support the error type is a `${0:_}`
// placeholder, otherwise a plain `_`.
66 match ctx.config.snippet_cap {
68 let snippet = format!("Result<{}, ${{0:_}}>", type_ref);
69 builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet)
72 .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)),
/// Accumulates, in `exprs_to_wrap`, the expressions found by the collector methods; the
/// assist later replaces each of them with its wrapped form.
78 struct TailReturnCollector {
79 exprs_to_wrap: Vec<ast::Expr>,
82 impl TailReturnCollector {
// A fresh collector starts with nothing to wrap.
84 Self { exprs_to_wrap: vec![] }
86 /// Collect the operands of `return` expressions (and of `break` expressions when
/// `collect_break` is set) found in the statements and the tail expression of
/// `block_expr`. The actual matching is delegated to `handle_exprs`.
87 fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) {
88 let statements = block_expr.statements();
89 for stmt in statements {
// Inspect the expression of an expression statement or the initializer of a `let`;
// items are skipped.
90 let expr = match &stmt {
91 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
92 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
93 ast::Stmt::Item(_) => continue,
95 if let Some(expr) = &expr {
96 self.handle_exprs(expr, collect_break);
100 // Browse tail expressions for each block
101 if let Some(expr) = block_expr.tail_expr() {
102 if let Some(last_exprs) = get_tail_expr_from_block(&expr) {
103 for last_expr in last_exprs {
// Both variants are reduced to a plain syntax node so that it can be re-cast below.
104 let last_expr = match last_expr {
105 NodeType::Node(expr) => expr,
106 NodeType::Leaf(expr) => expr.syntax().clone(),
// The node may be an expression or a statement; a statement contributes its
// expression or initializer, exactly as in the statement loop above.
109 if let Some(last_expr) = Expr::cast(last_expr.clone()) {
110 self.handle_exprs(&last_expr, collect_break);
111 } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) {
112 let expr_stmt = match &expr_stmt {
113 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
114 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
115 ast::Stmt::Item(_) => None,
117 if let Some(expr) = &expr_stmt {
118 self.handle_exprs(expr, collect_break);
/// Dispatches on the kind of `expr`: jump expressions contribute their operand to
/// `exprs_to_wrap`, block-like expressions are searched recursively with the same
/// `collect_break` flag.
126 fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) {
128 Expr::BlockExpr(block_expr) => {
129 self.collect_jump_exprs(&block_expr, collect_break);
131 Expr::ReturnExpr(ret_expr) => {
// A `return` without an operand adds nothing.
132 if let Some(ret_expr_arg) = &ret_expr.expr() {
133 self.exprs_to_wrap.push(ret_expr_arg.clone());
// `break` operands are recorded only when the caller asked for them.
136 Expr::BreakExpr(break_expr) if collect_break => {
137 if let Some(break_expr_arg) = &break_expr.expr() {
138 self.exprs_to_wrap.push(break_expr_arg.clone());
141 Expr::IfExpr(if_expr) => {
// Every block yielded by `if_expr.blocks()` is searched.
142 for block in if_expr.blocks() {
143 self.collect_jump_exprs(&block, collect_break);
// The three loop kinds are handled identically: search the loop body, if present.
146 Expr::LoopExpr(loop_expr) => {
147 if let Some(block_expr) = loop_expr.loop_body() {
148 self.collect_jump_exprs(&block_expr, collect_break);
151 Expr::ForExpr(for_expr) => {
152 if let Some(block_expr) = for_expr.loop_body() {
153 self.collect_jump_exprs(&block_expr, collect_break);
156 Expr::WhileExpr(while_expr) => {
157 if let Some(block_expr) = while_expr.loop_body() {
158 self.collect_jump_exprs(&block_expr, collect_break);
161 Expr::MatchExpr(match_expr) => {
// Each arm's expression is dispatched again through `handle_exprs`.
162 if let Some(arm_list) = match_expr.match_arm_list() {
163 arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| {
164 self.handle_exprs(&expr, collect_break);
/// Processes the tail expression of `block`, if there is one: jump expressions inside it
/// are collected with `break` operands included (`collect_break` is `true`), then its
/// tail values are gathered by `fetch_tail_exprs`.
172 fn collect_tail_exprs(&mut self, block: &BlockExpr) {
173 if let Some(expr) = block.tail_expr() {
174 self.handle_exprs(&expr, true);
175 self.fetch_tail_exprs(&expr);
/// Walks the result of `get_tail_expr_from_block` for `expr`: every `Leaf` is recorded in
/// `exprs_to_wrap`, every `Node` that casts to an expression is searched recursively.
179 fn fetch_tail_exprs(&mut self, expr: &Expr) {
180 if let Some(exprs) = get_tail_expr_from_block(expr) {
181 for node_type in &exprs {
183 NodeType::Leaf(expr) => {
184 self.exprs_to_wrap.push(expr.clone());
186 NodeType::Node(expr) => {
187 if let Some(last_expr) = Expr::cast(expr.clone()) {
188 self.fetch_tail_exprs(&last_expr);
203 /// Get the tail expressions of `expr`, descending into block-like expressions.
///
/// Each result is tagged either `NodeType::Leaf` (an expression that is itself a tail
/// value, e.g. a record, method-call, prefix or box expression) or `NodeType::Node`
/// (a syntax node that callers cast and inspect further).
204 fn get_tail_expr_from_block(expr: &Expr) -> Option<Vec<NodeType>> {
206 Expr::IfExpr(if_expr) => {
207 let mut nodes = vec![];
208 for block in if_expr.blocks() {
// A block with a tail expression contributes that expression's own tail expressions.
209 if let Some(block_expr) = block.tail_expr() {
210 if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) {
211 nodes.extend(tail_exprs);
// Without a tail expression the block is represented by its last child node.
213 } else if let Some(last_expr) = block.syntax().last_child() {
214 nodes.push(NodeType::Node(last_expr));
// Fallback: the block's own syntax node (presumably when it has no children -- TODO confirm).
216 nodes.push(NodeType::Node(block.syntax().clone()));
// Loops are represented by the last child of their syntax node
// (presumably the loop body -- TODO confirm).
221 Expr::LoopExpr(loop_expr) => {
222 loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
224 Expr::ForExpr(for_expr) => {
225 for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
227 Expr::WhileExpr(while_expr) => {
228 while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
230 Expr::BlockExpr(block_expr) => {
231 block_expr.tail_expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())])
233 Expr::MatchExpr(match_expr) => {
// One entry per arm that has an expression: `return`/`break` arms are kept as their
// own node, other arms use their last child node or, failing that, their own node.
234 let arm_list = match_expr.match_arm_list()?;
235 let arms: Vec<NodeType> = arm_list
237 .filter_map(|match_arm| match_arm.expr())
238 .map(|expr| match expr {
239 Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()),
240 Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()),
241 _ => match expr.syntax().last_child() {
242 Some(last_expr) => NodeType::Node(last_expr),
243 None => NodeType::Node(expr.syntax().clone()),
// A top-level `break` yields its operand as a leaf; a top-level `return` is kept as a node.
250 Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e)]),
251 Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]),
259 | Expr::RecordExpr(_)
261 | Expr::MethodCallExpr(_)
265 | Expr::PrefixExpr(_)
269 | Expr::BoxExpr(_) => Some(vec![NodeType::Leaf(expr.clone())]),
276 use crate::tests::{check_assist, check_assist_not_applicable};
281 fn wrap_return_type_in_result_simple() {
283 wrap_return_type_in_result,
291 fn foo() -> Result<i32, ${0:_}> {
300 fn wrap_return_type_in_result_simple_closure() {
302 wrap_return_type_in_result,
313 || -> Result<i32, ${0:_}> {
323 fn wrap_return_type_in_result_simple_return_type_bad_cursor() {
324 check_assist_not_applicable(
325 wrap_return_type_in_result,
336 fn wrap_return_type_in_result_simple_return_type_bad_cursor_closure() {
337 check_assist_not_applicable(
338 wrap_return_type_in_result,
351 fn wrap_return_type_in_result_closure_non_block() {
352 check_assist_not_applicable(wrap_return_type_in_result, r#"fn foo() { || -> i$032 3; }"#);
356 fn wrap_return_type_in_result_simple_return_type_already_result_std() {
357 check_assist_not_applicable(
358 wrap_return_type_in_result,
360 fn foo() -> std::result::Result<i32$0, String> {
369 fn wrap_return_type_in_result_simple_return_type_already_result() {
370 mark::check!(wrap_return_type_in_result_simple_return_type_already_result);
371 check_assist_not_applicable(
372 wrap_return_type_in_result,
374 fn foo() -> Result<i32$0, String> {
383 fn wrap_return_type_in_result_simple_return_type_already_result_closure() {
384 check_assist_not_applicable(
385 wrap_return_type_in_result,
388 || -> Result<i32$0, String> {
398 fn wrap_return_type_in_result_simple_with_cursor() {
400 wrap_return_type_in_result,
408 fn foo() -> Result<i32, ${0:_}> {
417 fn wrap_return_type_in_result_simple_with_tail() {
419 wrap_return_type_in_result,
427 fn foo() -> Result<i32, ${0:_}> {
436 fn wrap_return_type_in_result_simple_with_tail_closure() {
438 wrap_return_type_in_result,
449 || -> Result<i32, ${0:_}> {
459 fn wrap_return_type_in_result_simple_with_tail_only() {
461 wrap_return_type_in_result,
462 r#"fn foo() -> i32$0 { 42i32 }"#,
463 r#"fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }"#,
468 fn wrap_return_type_in_result_simple_with_tail_block_like() {
470 wrap_return_type_in_result,
481 fn foo() -> Result<i32, ${0:_}> {
493 fn wrap_return_type_in_result_simple_without_block_closure() {
495 wrap_return_type_in_result,
509 || -> Result<i32, ${0:_}> {
522 fn wrap_return_type_in_result_simple_with_nested_if() {
524 wrap_return_type_in_result,
539 fn foo() -> Result<i32, ${0:_}> {
555 fn wrap_return_type_in_result_simple_with_await() {
557 wrap_return_type_in_result,
559 async fn foo() -> i$032 {
572 async fn foo() -> Result<i32, ${0:_}> {
588 fn wrap_return_type_in_result_simple_with_array() {
590 wrap_return_type_in_result,
591 r#"fn foo() -> [i32;$0 3] { [1, 2, 3] }"#,
592 r#"fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }"#,
597 fn wrap_return_type_in_result_simple_with_cast() {
599 wrap_return_type_in_result,
614 fn foo() -> Result<i32, ${0:_}> {
630 fn wrap_return_type_in_result_simple_with_tail_block_like_match() {
632 wrap_return_type_in_result,
643 fn foo() -> Result<i32, ${0:_}> {
655 fn wrap_return_type_in_result_simple_with_loop_with_tail() {
657 wrap_return_type_in_result,
669 fn foo() -> Result<i32, ${0:_}> {
682 fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() {
684 wrap_return_type_in_result,
687 let my_var = let x = loop {
694 fn foo() -> Result<i32, ${0:_}> {
695 let my_var = let x = loop {
705 fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() {
707 wrap_return_type_in_result,
711 let res = match my_var {
719 fn foo() -> Result<i32, ${0:_}> {
721 let res = match my_var {
723 _ => return Ok(24i32),
731 wrap_return_type_in_result,
735 let res = if my_var == 5 {
744 fn foo() -> Result<i32, ${0:_}> {
746 let res = if my_var == 5 {
758 fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() {
760 wrap_return_type_in_result,
783 fn foo() -> Result<i32, ${0:_}> {
807 fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() {
809 wrap_return_type_in_result,
820 fn foo() -> Result<i32, ${0:_}> {
832 fn wrap_return_type_in_result_simple_with_closure() {
834 wrap_return_type_in_result,
836 fn foo(the_field: u32) ->$0 u32 {
837 let true_closure = || { return true; };
850 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
851 let true_closure = || { return true; };
866 wrap_return_type_in_result,
868 fn foo(the_field: u32) -> u32$0 {
869 let true_closure = || {
884 t.unwrap_or_else(|| the_field)
888 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
889 let true_closure = || {
904 Ok(t.unwrap_or_else(|| the_field))
911 fn wrap_return_type_in_result_simple_with_weird_forms() {
913 wrap_return_type_in_result,
930 fn foo() -> Result<i32, ${0:_}> {
947 wrap_return_type_in_result,
966 fn foo() -> Result<i32, ${0:_}> {
985 wrap_return_type_in_result,
991 let res = match other {
1008 fn foo() -> Result<i32, ${0:_}> {
1012 let res = match other {
1031 wrap_return_type_in_result,
1033 fn foo(the_field: u32) -> u32$0 {
1051 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1071 wrap_return_type_in_result,
1073 fn foo(the_field: u32) -> u3$02 {
1085 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1099 wrap_return_type_in_result,
1101 fn foo(the_field: u32) -> u32$0 {
1114 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
1129 wrap_return_type_in_result,
1131 fn foo(the_field: u32) -> $0u32 {
1144 fn foo(the_field: u32) -> Result<u32, ${0:_}> {