3 SyntaxKind::{COMMENT, WHITESPACE},
7 use crate::{Assist, AssistCtx, AssistId};
8 use ast::{BlockExpr, Expr, LoopBodyOwner};
10 // Assist: change_return_type_to_result
12 // Change the function's return type to Result.
15 // fn foo() -> i32<|> { 42i32 }
19 // fn foo() -> Result<i32, > { Ok(42i32) }
21 pub(crate) fn change_return_type_to_result(ctx: AssistCtx) -> Option<Assist> {
// Assist entry point: rewrites the return type `T` of the function under the
// cursor to `Result<T, >` and wraps the collected return values in `Ok(..)`.
// Every `?` below yields `None` (no `FnDef` at the offset, no return type,
// no type ref, or no body).
22 let fn_def = ctx.find_node_at_offset::<ast::FnDef>();
23 let fn_def = &mut fn_def?;
24 let ret_type = &fn_def.ret_type()?.type_ref()?;
// Purely textual test on the return type's source text.
// NOTE(review): a qualified path such as `std::result::Result<..>` or a type
// alias does not start with `Result<` and would not be caught here - confirm
// whether that is intended.
25 if ret_type.syntax().text().to_string().starts_with("Result<") {
29 let block_expr = &fn_def.body()?;
// True when the current selection range lies inside the return-type node.
30 let cursor_in_ret_type =
31 fn_def.ret_type()?.syntax().text_range().contains_range(ctx.frange.range);
32 if !cursor_in_ret_type {
// Assist id, user-facing label, and the text range of the return type.
37 AssistId("change_return_type_to_result"),
38 "Change return type to Result",
39 ret_type.syntax().text_range(),
// Gather every expression that has to be wrapped: first the arguments of
// `return` expressions in the body (`false`: `break` arguments are not
// collected in this pass), then the tail expressions of the body.
41 let mut tail_return_expr_collector = TailReturnCollector::new();
42 tail_return_expr_collector.collect_jump_exprs(block_expr, false);
43 tail_return_expr_collector.collect_tail_exprs(block_expr);
45 for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap {
46 edit.replace_node_and_indent(&ret_expr_arg, format!("Ok({})", ret_expr_arg));
// The error type is deliberately left empty: the new text is `Result<T, >`.
48 edit.replace_node_and_indent(ret_type.syntax(), format!("Result<{}, >", ret_type));
// Put the cursor right after `Result<T, `, i.e. at the empty error-type slot.
50 if let Some(node_start) = result_insertion_offset(&ret_type) {
51 edit.set_cursor(node_start + TextSize::of(&format!("Result<{}, ", ret_type)));
/// Accumulates the syntax nodes of the expressions that the assist wraps
/// in `Ok(..)`.
57 struct TailReturnCollector {
// Filled by the `collect_*` / `fetch_*` methods; consumed by the assist.
58 exprs_to_wrap: Vec<SyntaxNode>,
61 impl TailReturnCollector {
63 Self { exprs_to_wrap: vec![] }
65 /// Collect the arguments of all `return` expressions found in `block_expr`
/// (its statements and its tail expression), plus the arguments of `break`
/// expressions when `collect_break` is true.
66 fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) {
67 let statements = block_expr.statements();
68 for stmt in statements {
// Only the expression of an expression statement, or the initializer of a
// `let`, is inspected.
69 let expr = match &stmt {
70 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
71 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
73 if let Some(expr) = &expr {
74 self.handle_exprs(expr, collect_break);
78 // Browse tail expressions for each block
79 if let Some(expr) = block_expr.expr() {
80 if let Some(last_exprs) = get_tail_expr_from_block(&expr) {
81 for last_expr in last_exprs {
// Both variants carry a syntax node; the node/leaf distinction is not
// needed in this pass.
82 let last_expr = match last_expr {
83 NodeType::Node(expr) | NodeType::Leaf(expr) => expr,
// The node is handled as an expression when it casts to one; otherwise it
// is tried as a statement and unpacked like the statements above.
86 if let Some(last_expr) = Expr::cast(last_expr.clone()) {
87 self.handle_exprs(&last_expr, collect_break);
88 } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) {
89 let expr_stmt = match &expr_stmt {
90 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
91 ast::Stmt::LetStmt(stmt) => stmt.initializer(),
93 if let Some(expr) = &expr_stmt {
94 self.handle_exprs(expr, collect_break);
/// Dispatch on the kind of `expr`: record the argument of a `return` (and of
/// a `break` when `collect_break` is true) in `exprs_to_wrap`, and recurse
/// into the blocks, loop bodies and match arms that `expr` contains.
102 fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) {
// Nested block: scan it like a function body.
104 Expr::BlockExpr(block_expr) => {
105 self.collect_jump_exprs(&block_expr, collect_break);
// `return <arg>`: the argument is what gets wrapped; a bare `return` adds nothing.
107 Expr::ReturnExpr(ret_expr) => {
108 if let Some(ret_expr_arg) = &ret_expr.expr() {
109 self.exprs_to_wrap.push(ret_expr_arg.syntax().clone());
// `break <arg>` is only collected when the caller asked for it.
112 Expr::BreakExpr(break_expr) if collect_break => {
113 if let Some(break_expr_arg) = &break_expr.expr() {
114 self.exprs_to_wrap.push(break_expr_arg.syntax().clone());
// Every block yielded by `blocks()` of the `if` expression is scanned.
117 Expr::IfExpr(if_expr) => {
118 for block in if_expr.blocks() {
119 self.collect_jump_exprs(&block, collect_break);
// `loop`, `for` and `while` are all handled the same way: scan the body if
// there is one, passing `collect_break` through unchanged.
122 Expr::LoopExpr(loop_expr) => {
123 if let Some(block_expr) = loop_expr.loop_body() {
124 self.collect_jump_exprs(&block_expr, collect_break);
127 Expr::ForExpr(for_expr) => {
128 if let Some(block_expr) = for_expr.loop_body() {
129 self.collect_jump_exprs(&block_expr, collect_break);
132 Expr::WhileExpr(while_expr) => {
133 if let Some(block_expr) = while_expr.loop_body() {
134 self.collect_jump_exprs(&block_expr, collect_break);
// `match`: recurse into the expression of every arm that has one.
137 Expr::MatchExpr(match_expr) => {
138 if let Some(arm_list) = match_expr.match_arm_list() {
139 arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| {
140 self.handle_exprs(&expr, collect_break);
/// Process the tail expression of `block`, if it has one.
148 fn collect_tail_exprs(&mut self, block: &BlockExpr) {
149 if let Some(expr) = block.expr() {
// In tail position `break` arguments are collected too (`true`); afterwards
// the tail values themselves are gathered.
150 self.handle_exprs(&expr, true);
151 self.fetch_tail_exprs(&expr);
/// Push the tail nodes reported by `get_tail_expr_from_block` for `expr`
/// onto `exprs_to_wrap`, descending recursively through `Node` entries.
155 fn fetch_tail_exprs(&mut self, expr: &Expr) {
156 if let Some(exprs) = get_tail_expr_from_block(expr) {
157 for node_type in &exprs {
// A leaf is recorded as it is.
159 NodeType::Leaf(expr) => {
160 self.exprs_to_wrap.push(expr.clone());
// A node is first tried as an expression.
162 NodeType::Node(expr) => match &Expr::cast(expr.clone()) {
// NOTE(review): presumably the arm for a successful cast - recurse to find
// the tails of that expression. Confirm against the full match.
164 self.fetch_tail_exprs(last_expr);
// NOTE(review): presumably the arm for a node that is not an `Expr` - the
// node itself is recorded. Confirm against the full match.
167 self.exprs_to_wrap.push(expr.clone());
182 /// Get a tail expression inside a block
///
/// Returns the syntax nodes found in tail position of `expr`, each tagged as
/// `NodeType::Node` or `NodeType::Leaf`. `fetch_tail_exprs` recurses into
/// `Node` entries and records `Leaf` entries directly.
183 fn get_tail_expr_from_block(expr: &Expr) -> Option<Vec<NodeType>> {
// `if`: gather the tails of every block yielded by `blocks()`.
185 Expr::IfExpr(if_expr) => {
186 let mut nodes = vec![];
187 for block in if_expr.blocks() {
188 if let Some(block_expr) = block.expr() {
189 if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) {
190 nodes.extend(tail_exprs);
// No tail expression in this block: fall back to its last child node, or to
// the block node itself when it has no child node.
192 } else if let Some(last_expr) = block.syntax().last_child() {
193 nodes.push(NodeType::Node(last_expr));
195 nodes.push(NodeType::Node(block.syntax().clone()));
// Loops: report the last child node of the loop expression.
200 Expr::LoopExpr(loop_expr) => {
201 loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
203 Expr::ForExpr(for_expr) => {
204 for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
206 Expr::WhileExpr(while_expr) => {
207 while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
// Block: its own tail expression, if any.
209 Expr::BlockExpr(block_expr) => {
210 block_expr.expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())])
// `match`: one entry per arm that has an expression. `return` / `break` arms
// are reported as they are; any other arm contributes its last child node,
// or the arm expression itself when it has no child node.
212 Expr::MatchExpr(match_expr) => {
213 let arm_list = match_expr.match_arm_list()?;
214 let arms: Vec<NodeType> = arm_list
216 .filter_map(|match_arm| match_arm.expr())
217 .map(|expr| match expr {
218 Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()),
219 Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()),
220 _ => match expr.syntax().last_child() {
221 Some(last_expr) => NodeType::Node(last_expr),
222 None => NodeType::Node(expr.syntax().clone()),
// `break <arg>` yields its argument as a leaf (`None` for a bare `break`);
// `return` is reported as a node.
229 Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e.syntax().clone())]),
230 Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]),
// Every kind listed below is reported as a single leaf holding the whole
// expression.
231 Expr::CallExpr(call_expr) => Some(vec![NodeType::Leaf(call_expr.syntax().clone())]),
232 Expr::Literal(lit_expr) => Some(vec![NodeType::Leaf(lit_expr.syntax().clone())]),
233 Expr::TupleExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
234 Expr::ArrayExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
235 Expr::ParenExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
236 Expr::PathExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
237 Expr::Label(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
238 Expr::RecordLit(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
239 Expr::IndexExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
240 Expr::MethodCallExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
241 Expr::AwaitExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
242 Expr::CastExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
243 Expr::RefExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
244 Expr::PrefixExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
245 Expr::RangeExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
246 Expr::BinExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
247 Expr::MacroCall(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
248 Expr::BoxExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]),
/// Start offset of the first child element of `ret_type` that is neither a
/// comment nor whitespace; `None` when no such element exists.
253 fn result_insertion_offset(ret_type: &ast::TypeRef) -> Option<TextSize> {
254 let non_ws_child = ret_type
256 .children_with_tokens()
// Skip trivia so the offset points at the first meaningful token or node.
257 .find(|it| it.kind() != COMMENT && it.kind() != WHITESPACE)?;
258 Some(non_ws_child.text_range().start())
264 use crate::tests::{check_assist, check_assist_not_applicable};
269 fn change_return_type_to_result_simple() {
271 change_return_type_to_result,
272 r#"fn foo() -> i3<|>2 {
276 r#"fn foo() -> Result<i32, <|>> {
284 fn change_return_type_to_result_simple_return_type() {
286 change_return_type_to_result,
287 r#"fn foo() -> i32<|> {
291 r#"fn foo() -> Result<i32, <|>> {
299 fn change_return_type_to_result_simple_return_type_bad_cursor() {
300 check_assist_not_applicable(
301 change_return_type_to_result,
303 let test = "test";<|>
310 fn change_return_type_to_result_simple_with_cursor() {
312 change_return_type_to_result,
313 r#"fn foo() -> <|>i32 {
317 r#"fn foo() -> Result<i32, <|>> {
325 fn change_return_type_to_result_simple_with_tail() {
327 change_return_type_to_result,
328 r#"fn foo() -><|> i32 {
332 r#"fn foo() -> Result<i32, <|>> {
340 fn change_return_type_to_result_simple_with_tail_only() {
342 change_return_type_to_result,
343 r#"fn foo() -> i32<|> {
346 r#"fn foo() -> Result<i32, <|>> {
352 fn change_return_type_to_result_simple_with_tail_block_like() {
354 change_return_type_to_result,
355 r#"fn foo() -> i32<|> {
362 r#"fn foo() -> Result<i32, <|>> {
373 fn change_return_type_to_result_simple_with_nested_if() {
375 change_return_type_to_result,
376 r#"fn foo() -> i32<|> {
387 r#"fn foo() -> Result<i32, <|>> {
402 fn change_return_type_to_result_simple_with_await() {
404 change_return_type_to_result,
405 r#"async fn foo() -> i<|>32 {
416 r#"async fn foo() -> Result<i32, <|>> {
431 fn change_return_type_to_result_simple_with_array() {
433 change_return_type_to_result,
434 r#"fn foo() -> [i32;<|> 3] {
437 r#"fn foo() -> Result<[i32; 3], <|>> {
444 fn change_return_type_to_result_simple_with_cast() {
446 change_return_type_to_result,
447 r#"fn foo() -<|>> i32 {
458 r#"fn foo() -> Result<i32, <|>> {
473 fn change_return_type_to_result_simple_with_tail_block_like_match() {
475 change_return_type_to_result,
476 r#"fn foo() -> i32<|> {
483 r#"fn foo() -> Result<i32, <|>> {
494 fn change_return_type_to_result_simple_with_loop_with_tail() {
496 change_return_type_to_result,
497 r#"fn foo() -> i32<|> {
506 r#"fn foo() -> Result<i32, <|>> {
519 fn change_return_type_to_result_simple_with_loop_in_let_stmt() {
521 change_return_type_to_result,
522 r#"fn foo() -> i32<|> {
523 let my_var = let x = loop {
529 r#"fn foo() -> Result<i32, <|>> {
530 let my_var = let x = loop {
540 fn change_return_type_to_result_simple_with_tail_block_like_match_return_expr() {
542 change_return_type_to_result,
543 r#"fn foo() -> i32<|> {
545 let res = match my_var {
552 r#"fn foo() -> Result<i32, <|>> {
554 let res = match my_var {
556 _ => return Ok(24i32),
564 change_return_type_to_result,
565 r#"fn foo() -> i32<|> {
567 let res = if my_var == 5 {
575 r#"fn foo() -> Result<i32, <|>> {
577 let res = if my_var == 5 {
589 fn change_return_type_to_result_simple_with_tail_block_like_match_deeper() {
591 change_return_type_to_result,
592 r#"fn foo() -> i32<|> {
611 r#"fn foo() -> Result<i32, <|>> {
634 fn change_return_type_to_result_simple_with_tail_block_like_early_return() {
636 change_return_type_to_result,
637 r#"fn foo() -> i<|>32 {
644 r#"fn foo() -> Result<i32, <|>> {
655 fn change_return_type_to_result_simple_with_closure() {
657 change_return_type_to_result,
658 r#"fn foo(the_field: u32) -><|> u32 {
659 let true_closure = || {
675 r#"fn foo(the_field: u32) -> Result<u32, <|>> {
676 let true_closure = || {
695 change_return_type_to_result,
696 r#"fn foo(the_field: u32) -> u32<|> {
697 let true_closure = || {
712 t.unwrap_or_else(|| the_field)
714 r#"fn foo(the_field: u32) -> Result<u32, <|>> {
715 let true_closure = || {
730 Ok(t.unwrap_or_else(|| the_field))
736 fn change_return_type_to_result_simple_with_weird_forms() {
738 change_return_type_to_result,
739 r#"fn foo() -> i32<|> {
752 r#"fn foo() -> Result<i32, <|>> {
768 change_return_type_to_result,
769 r#"fn foo() -> i32<|> {
784 r#"fn foo() -> Result<i32, <|>> {
802 change_return_type_to_result,
803 r#"fn foo() -> i3<|>2 {
807 let res = match other {
822 r#"fn foo() -> Result<i32, <|>> {
826 let res = match other {
844 change_return_type_to_result,
845 r#"fn foo(the_field: u32) -> u32<|> {
863 r#"fn foo(the_field: u32) -> Result<u32, <|>> {
884 change_return_type_to_result,
885 r#"fn foo(the_field: u32) -> u3<|>2 {
897 r#"fn foo(the_field: u32) -> Result<u32, <|>> {
912 change_return_type_to_result,
913 r#"fn foo(the_field: u32) -> u32<|> {
926 r#"fn foo(the_field: u32) -> Result<u32, <|>> {
942 change_return_type_to_result,
943 r#"fn foo(the_field: u32) -> <|>u32 {
956 r#"fn foo(the_field: u32) -> Result<u32, <|>> {