]> git.lizzy.rs Git - rust.git/commitdiff
allow `&mut param` when extracting function
authorVladyslav Katasonov <cpud47@gmail.com>
Wed, 3 Feb 2021 21:27:31 +0000 (00:27 +0300)
committerVladyslav Katasonov <cpud47@gmail.com>
Wed, 3 Feb 2021 21:27:31 +0000 (00:27 +0300)
Recognise &mut as variable modification.
This allows extracting functions with
`&mut var` with `var` being in outer scope

crates/assists/src/handlers/extract_function.rs

index ffa8bd77dc60593aba9ae7e490ea0f2f31af63b7..a4b23d756234867804f8041cedbfedf8140af73e 100644 (file)
@@ -2,7 +2,7 @@
 use hir::{HirDisplay, Local};
 use ide_db::{
     defs::{Definition, NameRefClass},
-    search::{ReferenceAccess, SearchScope},
+    search::{FileReference, ReferenceAccess, SearchScope},
 };
 use itertools::Itertools;
 use stdx::format_to;
@@ -15,7 +15,7 @@
     },
     Direction, SyntaxElement,
     SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR},
-    SyntaxNode, TextRange, T,
+    SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, T,
 };
 use test_utils::mark;
 
@@ -140,7 +140,18 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
                 .iter()
                 .flat_map(|(_, rs)| rs.iter())
                 .filter(|reference| body.contains_range(reference.range))
-                .any(|reference| reference.access == Some(ReferenceAccess::Write));
+                .any(|reference| {
+                    if reference.access == Some(ReferenceAccess::Write) {
+                        return true;
+                    }
+
+                    let path = path_at_offset(&body, reference);
+                    if is_mut_ref_expr(path.as_ref()).unwrap_or(false) {
+                        return true;
+                    }
+
+                    false
+                });
 
             Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true }
         })
@@ -405,6 +416,19 @@ fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> Stri
     ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string())
 }
 
+fn path_at_offset(body: &FunctionBody, reference: &FileReference) -> Option<ast::Expr> {
+    let var = body.token_at_offset(reference.range.start()).right_biased()?;
+    let path = var.ancestors().find_map(ast::Expr::cast)?;
+    stdx::always!(matches!(path, ast::Expr::PathExpr(_)));
+    Some(path)
+}
+
+fn is_mut_ref_expr(path: Option<&ast::Expr>) -> Option<bool> {
+    let path = path?;
+    let ref_expr = path.syntax().parent().and_then(ast::RefExpr::cast)?;
+    Some(ref_expr.mut_token().is_some())
+}
+
 fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode {
     let mut rewriter = SyntaxRewriter::default();
     for param in params {
@@ -551,6 +575,38 @@ fn descendants(&self) -> impl Iterator<Item = SyntaxNode> + '_ {
         }
     }
 
+    fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset<SyntaxToken> {
+        match self {
+            FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset),
+            FunctionBody::Span { elements, .. } => {
+                stdx::always!(self.text_range().contains(offset));
+                let mut iter = elements
+                    .iter()
+                    .filter(|element| element.text_range().contains_inclusive(offset));
+                let element1 = iter.next().expect("offset does not fall into body");
+                let element2 = iter.next();
+                stdx::always!(iter.next().is_none(), "> 2 tokens at offset");
+                let t1 = match element1 {
+                    syntax::NodeOrToken::Node(node) => node.token_at_offset(offset),
+                    syntax::NodeOrToken::Token(token) => TokenAtOffset::Single(token.clone()),
+                };
+                let t2 = element2.map(|e| match e {
+                    syntax::NodeOrToken::Node(node) => node.token_at_offset(offset),
+                    syntax::NodeOrToken::Token(token) => TokenAtOffset::Single(token.clone()),
+                });
+
+                match t2 {
+                    Some(t2) => match (t1.clone().right_biased(), t2.clone().left_biased()) {
+                        (Some(e1), Some(e2)) => TokenAtOffset::Between(e1, e2),
+                        (Some(_), None) => t1,
+                        (None, _) => t2,
+                    },
+                    None => t1,
+                }
+            }
+        }
+    }
+
     fn text_range(&self) -> TextRange {
         match self {
             FunctionBody::Expr(expr) => expr.syntax().text_range(),
@@ -1403,6 +1459,54 @@ fn foo() {
 
 fn $0fun_name(mut n: i32) {
     n += 1;
+}",
+        );
+    }
+
+    #[test]
+    fn mut_param_because_of_mut_ref() {
+        check_assist(
+            extract_function,
+            r"
+fn foo() {
+    let mut n = 1;
+    $0let v = &mut n;
+    *v += 1;$0
+    let k = n;
+}",
+            r"
+fn foo() {
+    let mut n = 1;
+    fun_name(&mut n);
+    let k = n;
+}
+
+fn $0fun_name(n: &mut i32) {
+    let v = n;
+    *v += 1;
+}",
+        );
+    }
+
+    #[test]
+    fn mut_param_by_value_because_of_mut_ref() {
+        check_assist(
+            extract_function,
+            r"
+fn foo() {
+    let mut n = 1;
+    $0let v = &mut n;
+    *v += 1;$0
+}",
+            r"
+fn foo() {
+    let mut n = 1;
+    fun_name(n);
+}
+
+fn $0fun_name(mut n: i32) {
+    let v = &mut n;
+    *v += 1;
 }",
         );
     }