]> git.lizzy.rs Git - rust.git/blobdiff - src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_call.rs
:arrow_up: rust-analyzer
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / inline_call.rs
index 96890ad51a6f941e3399fea772058ab19d653af9..5ac18727c196050c48b507e379ded4a59010d010 100644 (file)
@@ -1,3 +1,5 @@
+use std::collections::BTreeSet;
+
 use ast::make;
 use either::Either;
 use hir::{db::HirDatabase, PathResolution, Semantics, TypeInfo};
@@ -7,6 +9,7 @@
     imports::insert_use::remove_path_if_in_use_stmt,
     path_transform::PathTransform,
     search::{FileReference, SearchScope},
+    source_change::SourceChangeBuilder,
     syntax_helpers::{insert_whitespace_into_node::insert_ws_into, node_ext::expr_as_name_ref},
     RootDatabase,
 };
@@ -100,18 +103,7 @@ pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext<'_>) ->
                 builder.edit_file(file_id);
                 let count = refs.len();
                 // The collects are required as we are otherwise iterating while mutating 🙅‍♀️🙅‍♂️
-                let (name_refs, name_refs_use): (Vec<_>, Vec<_>) = refs
-                    .into_iter()
-                    .filter_map(|file_ref| match file_ref.name {
-                        ast::NameLike::NameRef(name_ref) => Some(name_ref),
-                        _ => None,
-                    })
-                    .partition_map(|name_ref| {
-                        match name_ref.syntax().ancestors().find_map(ast::UseTree::cast) {
-                            Some(use_tree) => Either::Right(builder.make_mut(use_tree)),
-                            None => Either::Left(name_ref),
-                        }
-                    });
+                let (name_refs, name_refs_use) = split_refs_and_uses(builder, refs, Some);
                 let call_infos: Vec<_> = name_refs
                     .into_iter()
                     .filter_map(CallInfo::from_name_ref)
@@ -130,11 +122,7 @@ pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext<'_>) ->
                     .count();
                 if replaced + name_refs_use.len() == count {
                     // we replaced all usages in this file, so we can remove the imports
-                    name_refs_use.into_iter().for_each(|use_tree| {
-                        if let Some(path) = use_tree.path() {
-                            remove_path_if_in_use_stmt(&path);
-                        }
-                    })
+                    name_refs_use.iter().for_each(remove_path_if_in_use_stmt);
                 } else {
                     remove_def = false;
                 }
@@ -153,6 +141,23 @@ pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext<'_>) ->
     )
 }
 
+pub(super) fn split_refs_and_uses<T: ast::AstNode>(
+    builder: &mut SourceChangeBuilder,
+    iter: impl IntoIterator<Item = FileReference>,
+    mut map_ref: impl FnMut(ast::NameRef) -> Option<T>,
+) -> (Vec<T>, Vec<ast::Path>) {
+    iter.into_iter()
+        .filter_map(|file_ref| match file_ref.name {
+            ast::NameLike::NameRef(name_ref) => Some(name_ref),
+            _ => None,
+        })
+        .filter_map(|name_ref| match name_ref.syntax().ancestors().find_map(ast::UseTree::cast) {
+            Some(use_tree) => builder.make_mut(use_tree).path().map(Either::Right),
+            None => map_ref(name_ref).map(Either::Left),
+        })
+        .partition_map(|either| either)
+}
+
 // Assist: inline_call
 //
 // Inlines a function or method body creating a `let` statement per parameter unless the parameter
@@ -187,10 +192,10 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<
                 PathResolution::Def(hir::ModuleDef::Function(f)) => f,
                 _ => return None,
             };
-            (function, format!("Inline `{}`", path))
+            (function, format!("Inline `{path}`"))
         }
         ast::CallableExpr::MethodCall(call) => {
-            (ctx.sema.resolve_method_call(call)?, format!("Inline `{}`", name_ref))
+            (ctx.sema.resolve_method_call(call)?, format!("Inline `{name_ref}`"))
         }
     };
 
@@ -370,8 +375,44 @@ fn inline(
                 })
         }
     }
+
+    let mut func_let_vars: BTreeSet<String> = BTreeSet::new();
+
+    // grab all of the local variable declarations in the function
+    for stmt in fn_body.statements() {
+        if let Some(let_stmt) = ast::LetStmt::cast(stmt.syntax().to_owned()) {
+            for has_token in let_stmt.syntax().children_with_tokens() {
+                if let Some(node) = has_token.as_node() {
+                    if let Some(ident_pat) = ast::IdentPat::cast(node.to_owned()) {
+                        func_let_vars.insert(ident_pat.syntax().text().to_string());
+                    }
+                }
+            }
+        }
+    }
+
     // Inline parameter expressions or generate `let` statements depending on whether inlining works or not.
     for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments).rev() {
+        // izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors
+        let usages: &[ast::PathExpr] = &usages;
+        let expr: &ast::Expr = expr;
+
+        let insert_let_stmt = || {
+            let ty = sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone());
+            if let Some(stmt_list) = body.stmt_list() {
+                stmt_list.push_front(
+                    make::let_stmt(pat.clone(), ty, Some(expr.clone())).clone_for_update().into(),
+                )
+            }
+        };
+
+        // check if there is a local var in the function that conflicts with parameter
+        // if it does then emit a let statement and continue
+        if func_let_vars.contains(&expr.syntax().text().to_string()) {
+            insert_let_stmt();
+            continue;
+        }
+
         let inline_direct = |usage, replacement: &ast::Expr| {
             if let Some(field) = path_expr_as_record_field(usage) {
                 cov_mark::hit!(inline_call_inline_direct_field);
@@ -380,9 +421,7 @@ fn inline(
                 ted::replace(usage.syntax(), &replacement.syntax().clone_for_update());
             }
         };
-        // izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors
-        let usages: &[ast::PathExpr] = &*usages;
-        let expr: &ast::Expr = expr;
+
         match usages {
             // inline single use closure arguments
             [usage]
@@ -405,18 +444,11 @@ fn inline(
             }
             // can't inline, emit a let statement
             _ => {
-                let ty =
-                    sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone());
-                if let Some(stmt_list) = body.stmt_list() {
-                    stmt_list.push_front(
-                        make::let_stmt(pat.clone(), ty, Some(expr.clone()))
-                            .clone_for_update()
-                            .into(),
-                    )
-                }
+                insert_let_stmt();
             }
         }
     }
+
     if let Some(generic_arg_list) = generic_arg_list.clone() {
         if let Some((target, source)) = &sema.scope(node.syntax()).zip(sema.scope(fn_body.syntax()))
         {
@@ -1253,4 +1285,37 @@ fn main() {
 "#,
         )
     }
+
+    #[test]
+    fn local_variable_shadowing_callers_argument() {
+        check_assist(
+            inline_call,
+            r#"
+fn foo(bar: u32, baz: u32) -> u32 {
+    let a = 1;
+    bar * baz * a * 6
+}
+fn main() {
+    let a = 7;
+    let b = 1;
+    let res = foo$0(a, b);
+}
+"#,
+            r#"
+fn foo(bar: u32, baz: u32) -> u32 {
+    let a = 1;
+    bar * baz * a * 6
+}
+fn main() {
+    let a = 7;
+    let b = 1;
+    let res = {
+        let bar = a;
+        let a = 1;
+        bar * b * a * 6
+    };
+}
+"#,
+        );
+    }
 }