]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/inline_call.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / inline_call.rs
index bd566ec82087d8a88fff714fbc08ab2771e950ec..d88e3fdcd32ae6d51973cb9007fc6b7cd0276f0a 100644 (file)
@@ -1,18 +1,18 @@
 use ast::make;
 use either::Either;
-use hir::{db::HirDatabase, HasSource, PathResolution, Semantics, TypeInfo};
+use hir::{db::HirDatabase, PathResolution, Semantics, TypeInfo};
 use ide_db::{
     base_db::{FileId, FileRange},
     defs::Definition,
-    helpers::insert_use::remove_path_if_in_use_stmt,
+    helpers::{insert_use::remove_path_if_in_use_stmt, node_ext::expr_as_name_ref},
     path_transform::PathTransform,
     search::{FileReference, SearchScope},
     RootDatabase,
 };
 use itertools::{izip, Itertools};
 use syntax::{
-    ast::{self, edit_in_place::Indent, ArgListOwner},
-    ted, AstNode, SyntaxNode,
+    ast::{self, edit_in_place::Indent, HasArgList, PathExpr},
+    ted, AstNode,
 };
 
 use crate::{
@@ -59,7 +59,7 @@
 // }
 // ```
 pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
-    let def_file = ctx.frange.file_id;
+    let def_file = ctx.file_id();
     let name = ctx.find_node_at_offset::<ast::Name>()?;
     let ast_func = name.syntax().parent().and_then(ast::Fn::cast)?;
     let func_body = ast_func.body()?;
@@ -69,7 +69,7 @@ pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Opt
 
     let params = get_fn_params(ctx.sema.db, function, &param_list)?;
 
-    let usages = Definition::ModuleDef(hir::ModuleDef::Function(function)).usages(&ctx.sema);
+    let usages = Definition::Function(function).usages(&ctx.sema);
     if !usages.at_least_one() {
         return None;
     }
@@ -141,10 +141,9 @@ pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Opt
             for (file_id, refs) in usages.into_iter() {
                 inline_refs_for_file(file_id, refs);
             }
-            if let Some(refs) = current_file_usage {
-                inline_refs_for_file(def_file, refs);
-            } else {
-                builder.edit_file(def_file);
+            match current_file_usage {
+                Some(refs) => inline_refs_for_file(def_file, refs),
+                None => builder.edit_file(def_file),
             }
             if remove_def {
                 builder.delete(ast_func.syntax().text_range());
@@ -178,7 +177,7 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
     let name_ref: ast::NameRef = ctx.find_node_at_offset()?;
     let call_info = CallInfo::from_name_ref(name_ref.clone())?;
     let (function, label) = match &call_info.node {
-        CallExprNode::Call(call) => {
+        ast::CallableExpr::Call(call) => {
             let path = match call.expr()? {
                 ast::Expr::PathExpr(path) => path.path(),
                 _ => None,
@@ -190,17 +189,17 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
             };
             (function, format!("Inline `{}`", path))
         }
-        CallExprNode::MethodCallExpr(call) => {
+        ast::CallableExpr::MethodCall(call) => {
             (ctx.sema.resolve_method_call(call)?, format!("Inline `{}`", name_ref))
         }
     };
 
-    let fn_source = function.source(ctx.db())?;
+    let fn_source = ctx.sema.source(function)?;
     let fn_body = fn_source.value.body()?;
     let param_list = fn_source.value.param_list()?;
 
     let FileRange { file_id, range } = fn_source.syntax().original_file_range(ctx.sema.db);
-    if file_id == ctx.frange.file_id && range.contains(ctx.frange.range.start()) {
+    if file_id == ctx.file_id() && range.contains(ctx.offset()) {
         cov_mark::hit!(inline_call_recursive);
         return None;
     }
@@ -223,8 +222,8 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
 
             builder.replace_ast(
                 match call_info.node {
-                    CallExprNode::Call(it) => ast::Expr::CallExpr(it),
-                    CallExprNode::MethodCallExpr(it) => ast::Expr::MethodCallExpr(it),
+                    ast::CallableExpr::Call(it) => ast::Expr::CallExpr(it),
+                    ast::CallableExpr::MethodCall(it) => ast::Expr::MethodCallExpr(it),
                 },
                 replacement,
             );
@@ -232,22 +231,8 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
     )
 }
 
-enum CallExprNode {
-    Call(ast::CallExpr),
-    MethodCallExpr(ast::MethodCallExpr),
-}
-
-impl CallExprNode {
-    fn syntax(&self) -> &SyntaxNode {
-        match self {
-            CallExprNode::Call(it) => it.syntax(),
-            CallExprNode::MethodCallExpr(it) => it.syntax(),
-        }
-    }
-}
-
 struct CallInfo {
-    node: CallExprNode,
+    node: ast::CallableExpr,
     arguments: Vec<ast::Expr>,
     generic_arg_list: Option<ast::GenericArgList>,
 }
@@ -261,7 +246,7 @@ fn from_name_ref(name_ref: ast::NameRef) -> Option<CallInfo> {
             arguments.extend(call.arg_list()?.args());
             Some(CallInfo {
                 generic_arg_list: call.generic_arg_list(),
-                node: CallExprNode::MethodCallExpr(call),
+                node: ast::CallableExpr::MethodCall(call),
                 arguments,
             })
         } else if let Some(segment) = ast::PathSegment::cast(parent) {
@@ -271,7 +256,7 @@ fn from_name_ref(name_ref: ast::NameRef) -> Option<CallInfo> {
 
             Some(CallInfo {
                 arguments: call.arg_list()?.args().collect(),
-                node: CallExprNode::Call(call),
+                node: ast::CallableExpr::Call(call),
                 generic_arg_list: segment.generic_arg_list(),
             })
         } else {
@@ -359,11 +344,18 @@ fn inline(
     }
     // Inline parameter expressions or generate `let` statements depending on whether inlining works or not.
     for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments).rev() {
-        let expr_is_name_ref = matches!(&expr,
-            ast::Expr::PathExpr(expr)
-                if expr.path().and_then(|path| path.as_single_name_ref()).is_some()
-        );
-        match &*usages {
+        let inline_direct = |usage, replacement: &ast::Expr| {
+            if let Some(field) = path_expr_as_record_field(usage) {
+                cov_mark::hit!(inline_call_inline_direct_field);
+                field.replace_expr(replacement.clone_for_update());
+            } else {
+                ted::replace(usage.syntax(), &replacement.syntax().clone_for_update());
+            }
+        };
+        // izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors
+        let usages: &[ast::PathExpr] = &*usages;
+        let expr: &ast::Expr = expr;
+        match usages {
             // inline single use closure arguments
             [usage]
                 if matches!(expr, ast::Expr::ClosureExpr(_))
@@ -371,21 +363,19 @@ fn inline(
             {
                 cov_mark::hit!(inline_call_inline_closure);
                 let expr = make::expr_paren(expr.clone());
-                ted::replace(usage.syntax(), expr.syntax().clone_for_update());
+                inline_direct(usage, &expr);
             }
             // inline single use literals
             [usage] if matches!(expr, ast::Expr::Literal(_)) => {
                 cov_mark::hit!(inline_call_inline_literal);
-                ted::replace(usage.syntax(), expr.syntax().clone_for_update());
+                inline_direct(usage, &expr);
             }
             // inline direct local arguments
-            [_, ..] if expr_is_name_ref => {
+            [_, ..] if expr_as_name_ref(&expr).is_some() => {
                 cov_mark::hit!(inline_call_inline_locals);
-                usages.into_iter().for_each(|usage| {
-                    ted::replace(usage.syntax(), &expr.syntax().clone_for_update());
-                });
+                usages.into_iter().for_each(|usage| inline_direct(usage, &expr));
             }
-            // cant inline, emit a let statement
+            // can't inline, emit a let statement
             _ => {
                 let ty =
                     sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone());
@@ -410,17 +400,33 @@ fn inline(
     }
 
     let original_indentation = match node {
-        CallExprNode::Call(it) => it.indent_level(),
-        CallExprNode::MethodCallExpr(it) => it.indent_level(),
+        ast::CallableExpr::Call(it) => it.indent_level(),
+        ast::CallableExpr::MethodCall(it) => it.indent_level(),
     };
     body.reindent_to(original_indentation);
 
     match body.tail_expr() {
         Some(expr) if body.statements().next().is_none() => expr,
-        _ => ast::Expr::BlockExpr(body),
+        _ => match node
+            .syntax()
+            .parent()
+            .and_then(ast::BinExpr::cast)
+            .and_then(|bin_expr| bin_expr.lhs())
+        {
+            Some(lhs) if lhs.syntax() == node.syntax() => {
+                make::expr_paren(ast::Expr::BlockExpr(body)).clone_for_update()
+            }
+            _ => ast::Expr::BlockExpr(body),
+        },
     }
 }
 
+fn path_expr_as_record_field(usage: &PathExpr) -> Option<ast::RecordExprField> {
+    let path = usage.path()?;
+    let name_ref = path.as_single_name_ref()?;
+    ast::RecordExprField::for_name_ref(&name_ref)
+}
+
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
@@ -1025,4 +1031,115 @@ fn foo() {
 "#,
         );
     }
+
+    #[test]
+    fn inline_call_field_shorthand() {
+        cov_mark::check!(inline_call_inline_direct_field);
+        check_assist(
+            inline_call,
+            r#"
+struct Foo {
+    field: u32,
+    field1: u32,
+    field2: u32,
+    field3: u32,
+}
+fn foo(field: u32, field1: u32, val2: u32, val3: u32) -> Foo {
+    Foo {
+        field,
+        field1,
+        field2: val2,
+        field3: val3,
+    }
+}
+fn main() {
+    let bar = 0;
+    let baz = 0;
+    foo$0(bar, 0, baz, 0);
+}
+"#,
+            r#"
+struct Foo {
+    field: u32,
+    field1: u32,
+    field2: u32,
+    field3: u32,
+}
+fn foo(field: u32, field1: u32, val2: u32, val3: u32) -> Foo {
+    Foo {
+        field,
+        field1,
+        field2: val2,
+        field3: val3,
+    }
+}
+fn main() {
+    let bar = 0;
+    let baz = 0;
+    Foo {
+            field: bar,
+            field1: 0,
+            field2: baz,
+            field3: 0,
+        };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn inline_callers_wrapped_in_parentheses() {
+        check_assist(
+            inline_into_callers,
+            r#"
+fn foo$0() -> u32 {
+    let x = 0;
+    x
+}
+fn bar() -> u32 {
+    foo() + foo()
+}
+"#,
+            r#"
+
+fn bar() -> u32 {
+    ({
+        let x = 0;
+        x
+    }) + {
+        let x = 0;
+        x
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn inline_call_wrapped_in_parentheses() {
+        check_assist(
+            inline_call,
+            r#"
+fn foo() -> u32 {
+    let x = 0;
+    x
+}
+fn bar() -> u32 {
+    foo$0() + foo()
+}
+"#,
+            r#"
+fn foo() -> u32 {
+    let x = 0;
+    x
+}
+fn bar() -> u32 {
+    ({
+        let x = 0;
+        x
+    }) + foo()
+}
+"#,
+        )
+    }
 }