]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/inline_call.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / inline_call.rs
index d070eefd9c6db9f1e6196b09f8c189140a598f2f..d88e3fdcd32ae6d51973cb9007fc6b7cd0276f0a 100644 (file)
@@ -59,7 +59,7 @@
 // }
 // ```
 pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
-    let def_file = ctx.frange.file_id;
+    let def_file = ctx.file_id();
     let name = ctx.find_node_at_offset::<ast::Name>()?;
     let ast_func = name.syntax().parent().and_then(ast::Fn::cast)?;
     let func_body = ast_func.body()?;
@@ -69,7 +69,7 @@ pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Opt
 
     let params = get_fn_params(ctx.sema.db, function, &param_list)?;
 
-    let usages = Definition::ModuleDef(hir::ModuleDef::Function(function)).usages(&ctx.sema);
+    let usages = Definition::Function(function).usages(&ctx.sema);
     if !usages.at_least_one() {
         return None;
     }
@@ -199,7 +199,7 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
     let param_list = fn_source.value.param_list()?;
 
     let FileRange { file_id, range } = fn_source.syntax().original_file_range(ctx.sema.db);
-    if file_id == ctx.frange.file_id && range.contains(ctx.frange.range.start()) {
+    if file_id == ctx.file_id() && range.contains(ctx.offset()) {
         cov_mark::hit!(inline_call_recursive);
         return None;
     }
@@ -407,7 +407,17 @@ fn inline(
 
     match body.tail_expr() {
         Some(expr) if body.statements().next().is_none() => expr,
-        _ => ast::Expr::BlockExpr(body),
+        _ => match node
+            .syntax()
+            .parent()
+            .and_then(ast::BinExpr::cast)
+            .and_then(|bin_expr| bin_expr.lhs())
+        {
+            Some(lhs) if lhs.syntax() == node.syntax() => {
+                make::expr_paren(ast::Expr::BlockExpr(body)).clone_for_update()
+            }
+            _ => ast::Expr::BlockExpr(body),
+        },
     }
 }
 
@@ -1076,4 +1086,60 @@ fn main() {
 "#,
         );
     }
+
+    #[test]
+    fn inline_callers_wrapped_in_parentheses() {
+        check_assist(
+            inline_into_callers,
+            r#"
+fn foo$0() -> u32 {
+    let x = 0;
+    x
+}
+fn bar() -> u32 {
+    foo() + foo()
+}
+"#,
+            r#"
+
+fn bar() -> u32 {
+    ({
+        let x = 0;
+        x
+    }) + {
+        let x = 0;
+        x
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn inline_call_wrapped_in_parentheses() {
+        check_assist(
+            inline_call,
+            r#"
+fn foo() -> u32 {
+    let x = 0;
+    x
+}
+fn bar() -> u32 {
+    foo$0() + foo()
+}
+"#,
+            r#"
+fn foo() -> u32 {
+    let x = 0;
+    x
+}
+fn bar() -> u32 {
+    ({
+        let x = 0;
+        x
+    }) + foo()
+}
+"#,
+        )
+    }
 }