use std::iter;
+use ide_db::helpers::{for_each_tail_expr, node_ext::walk_expr, FamousDefs};
use syntax::{
- ast::{self, make, BlockExpr, Expr, LoopBodyOwner},
- match_ast, AstNode, SyntaxNode,
+ ast::{self, make, Expr},
+ match_ast, AstNode,
};
use crate::{AssistContext, AssistId, AssistKind, Assists};
// Wrap the function's return type into Result.
//
// ```
+// # //- minicore: result
// fn foo() -> i32$0 { 42i32 }
// ```
// ->
pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
let parent = ret_type.syntax().parent()?;
- let block_expr = match_ast! {
+ let body = match_ast! {
match parent {
ast::Fn(func) => func.body()?,
ast::ClosureExpr(closure) => match closure.body()? {
};
let type_ref = &ret_type.ty()?;
- let ret_type_str = type_ref.syntax().text().to_string();
- let first_part_ret_type = ret_type_str.splitn(2, '<').next();
- if let Some(ret_type_first_part) = first_part_ret_type {
- if ret_type_first_part.ends_with("Result") {
- cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
- return None;
- }
+ let ty = ctx.sema.resolve_type(type_ref).and_then(|ty| ty.as_adt());
+ let result_enum =
+ FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax()).krate()).core_result_Result()?;
+
+ if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
+ cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
+ return None;
}
acc.add(
"Wrap return type in Result",
type_ref.syntax().text_range(),
|builder| {
- let mut tail_return_expr_collector = TailReturnCollector::new();
- tail_return_expr_collector.collect_jump_exprs(&block_expr, false);
- tail_return_expr_collector.collect_tail_exprs(&block_expr);
+ let body = ast::Expr::BlockExpr(body);
+
+ let mut exprs_to_wrap = Vec::new();
+ let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e);
+ walk_expr(&body, &mut |expr| {
+ if let Expr::ReturnExpr(ret_expr) = expr {
+ if let Some(ret_expr_arg) = &ret_expr.expr() {
+ for_each_tail_expr(ret_expr_arg, tail_cb);
+ }
+ }
+ });
+ for_each_tail_expr(&body, tail_cb);
- for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap {
+ for ret_expr_arg in exprs_to_wrap {
let ok_wrapped = make::expr_call(
make::expr_path(make::ext::ident_path("Ok")),
make::arg_list(iter::once(ret_expr_arg.clone())),
)
}
-struct TailReturnCollector {
- exprs_to_wrap: Vec<ast::Expr>,
-}
-
-impl TailReturnCollector {
- fn new() -> Self {
- Self { exprs_to_wrap: vec![] }
- }
- /// Collect all`return` expression
- fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) {
- let statements = block_expr.statements();
- for stmt in statements {
- let expr = match &stmt {
- ast::Stmt::ExprStmt(stmt) => stmt.expr(),
- ast::Stmt::LetStmt(stmt) => stmt.initializer(),
- ast::Stmt::Item(_) => continue,
- };
- if let Some(expr) = &expr {
- self.handle_exprs(expr, collect_break);
+fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
+ match e {
+ Expr::BreakExpr(break_expr) => {
+ if let Some(break_expr_arg) = break_expr.expr() {
+ for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
}
}
-
- // Browse tail expressions for each block
- if let Some(expr) = block_expr.tail_expr() {
- if let Some(last_exprs) = get_tail_expr_from_block(&expr) {
- for last_expr in last_exprs {
- let last_expr = match last_expr {
- NodeType::Node(expr) => expr,
- NodeType::Leaf(expr) => expr.syntax().clone(),
- };
-
- if let Some(last_expr) = Expr::cast(last_expr.clone()) {
- self.handle_exprs(&last_expr, collect_break);
- } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) {
- let expr_stmt = match &expr_stmt {
- ast::Stmt::ExprStmt(stmt) => stmt.expr(),
- ast::Stmt::LetStmt(stmt) => stmt.initializer(),
- ast::Stmt::Item(_) => None,
- };
- if let Some(expr) = &expr_stmt {
- self.handle_exprs(expr, collect_break);
- }
- }
- }
- }
- }
- }
-
- fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) {
- match expr {
- Expr::BlockExpr(block_expr) => {
- self.collect_jump_exprs(&block_expr, collect_break);
- }
- Expr::ReturnExpr(ret_expr) => {
- if let Some(ret_expr_arg) = &ret_expr.expr() {
- self.exprs_to_wrap.push(ret_expr_arg.clone());
- }
- }
- Expr::BreakExpr(break_expr) if collect_break => {
- if let Some(break_expr_arg) = &break_expr.expr() {
- self.exprs_to_wrap.push(break_expr_arg.clone());
- }
- }
- Expr::IfExpr(if_expr) => {
- for block in if_expr.blocks() {
- self.collect_jump_exprs(&block, collect_break);
- }
- }
- Expr::LoopExpr(loop_expr) => {
- if let Some(block_expr) = loop_expr.loop_body() {
- self.collect_jump_exprs(&block_expr, collect_break);
- }
- }
- Expr::ForExpr(for_expr) => {
- if let Some(block_expr) = for_expr.loop_body() {
- self.collect_jump_exprs(&block_expr, collect_break);
- }
- }
- Expr::WhileExpr(while_expr) => {
- if let Some(block_expr) = while_expr.loop_body() {
- self.collect_jump_exprs(&block_expr, collect_break);
- }
- }
- Expr::MatchExpr(match_expr) => {
- if let Some(arm_list) = match_expr.match_arm_list() {
- arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| {
- self.handle_exprs(&expr, collect_break);
- });
- }
- }
- _ => {}
- }
- }
-
- fn collect_tail_exprs(&mut self, block: &BlockExpr) {
- if let Some(expr) = block.tail_expr() {
- self.handle_exprs(&expr, true);
- self.fetch_tail_exprs(&expr);
- }
- }
-
- fn fetch_tail_exprs(&mut self, expr: &Expr) {
- if let Some(exprs) = get_tail_expr_from_block(expr) {
- for node_type in &exprs {
- match node_type {
- NodeType::Leaf(expr) => {
- self.exprs_to_wrap.push(expr.clone());
- }
- NodeType::Node(expr) => {
- if let Some(last_expr) = Expr::cast(expr.clone()) {
- self.fetch_tail_exprs(&last_expr);
- }
- }
- }
+ Expr::ReturnExpr(ret_expr) => {
+ if let Some(ret_expr_arg) = &ret_expr.expr() {
+ for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e));
}
}
- }
-}
-
-#[derive(Debug)]
-enum NodeType {
- Leaf(ast::Expr),
- Node(SyntaxNode),
-}
-
-/// Get a tail expression inside a block
-fn get_tail_expr_from_block(expr: &Expr) -> Option<Vec<NodeType>> {
- match expr {
- Expr::IfExpr(if_expr) => {
- let mut nodes = vec![];
- for block in if_expr.blocks() {
- if let Some(block_expr) = block.tail_expr() {
- if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) {
- nodes.extend(tail_exprs);
- }
- } else if let Some(last_expr) = block.syntax().last_child() {
- nodes.push(NodeType::Node(last_expr));
- } else {
- nodes.push(NodeType::Node(block.syntax().clone()));
- }
- }
- Some(nodes)
- }
- Expr::LoopExpr(loop_expr) => {
- loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
- }
- Expr::ForExpr(for_expr) => {
- for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
- }
- Expr::WhileExpr(while_expr) => {
- while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)])
- }
- Expr::BlockExpr(block_expr) => {
- block_expr.tail_expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())])
- }
- Expr::MatchExpr(match_expr) => {
- let arm_list = match_expr.match_arm_list()?;
- let arms: Vec<NodeType> = arm_list
- .arms()
- .filter_map(|match_arm| match_arm.expr())
- .map(|expr| match expr {
- Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()),
- Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()),
- _ => match expr.syntax().last_child() {
- Some(last_expr) => NodeType::Node(last_expr),
- None => NodeType::Node(expr.syntax().clone()),
- },
- })
- .collect();
-
- Some(arms)
- }
- Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e)]),
- Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]),
-
- Expr::CallExpr(_)
- | Expr::Literal(_)
- | Expr::TupleExpr(_)
- | Expr::ArrayExpr(_)
- | Expr::ParenExpr(_)
- | Expr::PathExpr(_)
- | Expr::RecordExpr(_)
- | Expr::IndexExpr(_)
- | Expr::MethodCallExpr(_)
- | Expr::AwaitExpr(_)
- | Expr::CastExpr(_)
- | Expr::RefExpr(_)
- | Expr::PrefixExpr(_)
- | Expr::RangeExpr(_)
- | Expr::BinExpr(_)
- | Expr::MacroCall(_)
- | Expr::BoxExpr(_) => Some(vec![NodeType::Leaf(expr.clone())]),
- _ => None,
+ e => acc.push(e.clone()),
}
}
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i3$02 {
let test = "test";
return 42i32;
);
}
+ #[test]
+ fn wrap_return_type_break_split_tail() {
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+fn foo() -> i3$02 {
+ loop {
+ break if true {
+ 1
+ } else {
+ 0
+ };
+ }
+}
+"#,
+ r#"
+fn foo() -> Result<i32, ${0:_}> {
+ loop {
+ break if true {
+ Ok(1)
+ } else {
+ Ok(0)
+ };
+ }
+}
+"#,
+ );
+ }
+
#[test]
fn wrap_return_type_in_result_simple_closure() {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() {
|| -> i32$0 {
let test = "test";
check_assist_not_applicable(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32 {
let test = "test";$0
return 42i32;
check_assist_not_applicable(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() {
|| -> i32 {
let test = "test";$0
#[test]
fn wrap_return_type_in_result_closure_non_block() {
- check_assist_not_applicable(wrap_return_type_in_result, r#"fn foo() { || -> i$032 3; }"#);
+ check_assist_not_applicable(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+fn foo() { || -> i$032 3; }
+"#,
+ );
}
#[test]
check_assist_not_applicable(
wrap_return_type_in_result,
r#"
-fn foo() -> std::result::Result<i32$0, String> {
+//- minicore: result
+fn foo() -> core::result::Result<i32$0, String> {
let test = "test";
return 42i32;
}
check_assist_not_applicable(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> Result<i32$0, String> {
let test = "test";
return 42i32;
check_assist_not_applicable(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() {
|| -> Result<i32$0, String> {
let test = "test";
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> $0i32 {
let test = "test";
return 42i32;
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() ->$0 i32 {
let test = "test";
42i32
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() {
|| ->$0 i32 {
let test = "test";
fn wrap_return_type_in_result_simple_with_tail_only() {
check_assist(
wrap_return_type_in_result,
- r#"fn foo() -> i32$0 { 42i32 }"#,
- r#"fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }"#,
+ r#"
+//- minicore: result
+fn foo() -> i32$0 { 42i32 }
+"#,
+ r#"
+fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
+"#,
);
}
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32$0 {
if true {
42i32
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() {
|| -> i32$0 {
if true {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32$0 {
if true {
if false {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
async fn foo() -> i$032 {
if true {
if false {
fn wrap_return_type_in_result_simple_with_array() {
check_assist(
wrap_return_type_in_result,
- r#"fn foo() -> [i32;$0 3] { [1, 2, 3] }"#,
- r#"fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }"#,
+ r#"
+//- minicore: result
+fn foo() -> [i32;$0 3] { [1, 2, 3] }
+"#,
+ r#"
+fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }
+"#,
);
}
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -$0> i32 {
if true {
if false {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32$0 {
let my_var = 5;
match my_var {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32$0 {
let my_var = 5;
loop {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32$0 {
let my_var = let x = loop {
break 1;
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32$0 {
let my_var = 5;
let res = match my_var {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32$0 {
let my_var = 5;
let res = if my_var == 5 {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32$0 {
let my_var = 5;
match my_var {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i$032 {
let test = "test";
if test == "test" {
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo(the_field: u32) ->$0 u32 {
let true_closure = || { return true; };
if the_field < 5 {
check_assist(
wrap_return_type_in_result,
r#"
- fn foo(the_field: u32) -> u32$0 {
- let true_closure = || {
- return true;
- };
- if the_field < 5 {
- let mut i = 0;
+//- minicore: result
+fn foo(the_field: u32) -> u32$0 {
+ let true_closure = || {
+ return true;
+ };
+ if the_field < 5 {
+ let mut i = 0;
- if true_closure() {
- return 99;
- } else {
- return 0;
- }
- }
- let t = None;
+ if true_closure() {
+ return 99;
+ } else {
+ return 0;
+ }
+ }
+ let t = None;
- t.unwrap_or_else(|| the_field)
- }
- "#,
+ t.unwrap_or_else(|| the_field)
+}
+"#,
r#"
- fn foo(the_field: u32) -> Result<u32, ${0:_}> {
- let true_closure = || {
- return true;
- };
- if the_field < 5 {
- let mut i = 0;
+fn foo(the_field: u32) -> Result<u32, ${0:_}> {
+ let true_closure = || {
+ return true;
+ };
+ if the_field < 5 {
+ let mut i = 0;
- if true_closure() {
- return Ok(99);
- } else {
- return Ok(0);
- }
- }
- let t = None;
+ if true_closure() {
+ return Ok(99);
+ } else {
+ return Ok(0);
+ }
+ }
+ let t = None;
- Ok(t.unwrap_or_else(|| the_field))
- }
- "#,
+ Ok(t.unwrap_or_else(|| the_field))
+}
+"#,
);
}
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo() -> i32$0 {
let test = "test";
if test == "test" {
check_assist(
wrap_return_type_in_result,
r#"
-fn foo() -> i32$0 {
- let test = "test";
- if test == "test" {
- return 24i32;
- }
- let mut i = 0;
- loop {
- loop {
- if i == 1 {
- break 55;
- }
- i += 1;
- }
- }
-}
-"#,
- r#"
-fn foo() -> Result<i32, ${0:_}> {
- let test = "test";
- if test == "test" {
- return Ok(24i32);
- }
- let mut i = 0;
- loop {
- loop {
- if i == 1 {
- break Ok(55);
- }
- i += 1;
- }
- }
-}
-"#,
- );
-
- check_assist(
- wrap_return_type_in_result,
- r#"
-fn foo() -> i3$02 {
- let test = "test";
- let other = 5;
- if test == "test" {
- let res = match other {
- 5 => 43,
- _ => return 56,
- };
- }
- let mut i = 0;
- loop {
- loop {
- if i == 1 {
- break 55;
- }
- i += 1;
- }
- }
-}
-"#,
- r#"
-fn foo() -> Result<i32, ${0:_}> {
- let test = "test";
- let other = 5;
- if test == "test" {
- let res = match other {
- 5 => 43,
- _ => return Ok(56),
- };
- }
- let mut i = 0;
- loop {
- loop {
- if i == 1 {
- break Ok(55);
- }
- i += 1;
- }
- }
-}
-"#,
- );
-
- check_assist(
- wrap_return_type_in_result,
- r#"
+//- minicore: result
fn foo(the_field: u32) -> u32$0 {
if the_field < 5 {
let mut i = 0;
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo(the_field: u32) -> u3$02 {
if the_field < 5 {
let mut i = 0;
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo(the_field: u32) -> u32$0 {
if the_field < 5 {
let mut i = 0;
check_assist(
wrap_return_type_in_result,
r#"
+//- minicore: result
fn foo(the_field: u32) -> $0u32 {
if the_field < 5 {
let mut i = 0;