]> git.lizzy.rs Git - rust.git/commitdiff
allow modifications of vars from outer scope inside extracted function
authorVladyslav Katasonov <cpud47@gmail.com>
Wed, 3 Feb 2021 20:45:03 +0000 (23:45 +0300)
committerVladyslav Katasonov <cpud47@gmail.com>
Wed, 3 Feb 2021 20:45:03 +0000 (23:45 +0300)
It currently allows only directly setting variable.
No `&mut` references or methods.

crates/assists/src/handlers/extract_function.rs
crates/syntax/src/ast/make.rs

index c5e6ec7331bd95e48c517fa69dab272fa8a8e7c1..ffa8bd77dc60593aba9ae7e490ea0f2f31af63b7 100644 (file)
@@ -2,19 +2,20 @@
 use hir::{HirDisplay, Local};
 use ide_db::{
     defs::{Definition, NameRefClass},
-    search::SearchScope,
+    search::{ReferenceAccess, SearchScope},
 };
 use itertools::Itertools;
 use stdx::format_to;
 use syntax::{
+    algo::SyntaxRewriter,
     ast::{
         self,
         edit::{AstNodeEdit, IndentLevel},
-        AstNode, NameOwner,
+        AstNode,
     },
     Direction, SyntaxElement,
     SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR},
-    SyntaxNode, TextRange,
+    SyntaxNode, TextRange, T,
 };
 use test_utils::mark;
 
@@ -88,16 +89,16 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
     let mut self_param = None;
     let param_pats: Vec<_> = vars_used_in_body
         .iter()
-        .map(|node| node.source(ctx.db()))
-        .filter(|src| {
+        .map(|node| (node, node.source(ctx.db())))
+        .filter(|(_, src)| {
             src.file_id.original_file(ctx.db()) == ctx.frange.file_id
                 && !body.contains_node(&either_syntax(&src.value))
         })
-        .filter_map(|src| match src.value {
-            Either::Left(pat) => Some(pat),
+        .filter_map(|(&node, src)| match src.value {
+            Either::Left(_) => Some(node),
             Either::Right(it) => {
                 // we filter self param, as there can only be one
-                self_param = Some(it);
+                self_param = Some((node, it));
                 None
             }
         })
@@ -109,7 +110,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
 
     let vars_defined_in_body = vars_defined_in_body(&body, ctx);
 
-    let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body
+    let vars_defined_in_body_and_outlive: Vec<_> = vars_defined_in_body
         .iter()
         .copied()
         .filter(|node| {
@@ -123,20 +124,27 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
         })
         .collect();
 
-    let params = param_pats
+    let params: Vec<_> = param_pats
         .into_iter()
-        .map(|pat| {
-            let name = pat.name().unwrap().to_string();
-
-            let ty = ctx
-                .sema
-                .type_of_pat(&pat.into())
-                .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok())
-                .unwrap_or_else(|| "()".to_string());
+        .map(|node| {
+            let usages = Definition::Local(node)
+                .usages(&ctx.sema)
+                .in_scope(SearchScope::single_file(ctx.frange.file_id))
+                .all();
 
-            Param { name, ty }
+            let has_usages_afterwards = usages
+                .iter()
+                .flat_map(|(_, rs)| rs.iter())
+                .any(|reference| body.preceedes_range(reference.range));
+            let has_mut_inside_body = usages
+                .iter()
+                .flat_map(|(_, rs)| rs.iter())
+                .filter(|reference| body.contains_range(reference.range))
+                .any(|reference| reference.access == Some(ReferenceAccess::Write));
+
+            Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true }
         })
-        .collect::<Vec<_>>();
+        .collect();
 
     let expr = body.tail_expr();
     let ret_ty = match expr {
@@ -145,7 +153,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
     };
 
     let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit());
-    if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) {
+    if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !has_unit_ret) {
         // We should not have variables that outlive body if we have expression block
         return None;
     }
@@ -162,11 +170,11 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
         move |builder| {
             let fun = Function {
                 name: "fun_name".to_string(),
-                self_param,
+                self_param: self_param.map(|(_, pat)| pat),
                 params,
                 ret_ty,
                 body,
-                vars_in_body_used_afterwards,
+                vars_defined_in_body_and_outlive,
             };
 
             builder.replace(target_range, format_replacement(ctx, &fun));
@@ -183,17 +191,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
 fn format_replacement(ctx: &AssistContext, fun: &Function) -> String {
     let mut buf = String::new();
 
-    match fun.vars_in_body_used_afterwards.len() {
-        0 => {}
-        1 => format_to!(
-            buf,
-            "let {} = ",
-            fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()
-        ),
-        _ => {
+    match fun.vars_defined_in_body_and_outlive.as_slice() {
+        [] => {}
+        [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()),
+        [v0, vs @ ..] => {
             buf.push_str("let (");
-            format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap());
-            for local in fun.vars_in_body_used_afterwards.iter().skip(1) {
+            format_to!(buf, "{}", v0.name(ctx.db()).unwrap());
+            for local in vs {
                 format_to!(buf, ", {}", local.name(ctx.db()).unwrap());
             }
             buf.push_str(") = ");
@@ -207,10 +211,10 @@ fn format_replacement(ctx: &AssistContext, fun: &Function) -> String {
     {
         let mut it = fun.params.iter();
         if let Some(param) = it.next() {
-            format_to!(buf, "{}", param.name);
+            format_to!(buf, "{}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap());
         }
         for param in it {
-            format_to!(buf, ", {}", param.name);
+            format_to!(buf, ", {}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap());
         }
     }
     format_to!(buf, ")");
@@ -228,7 +232,7 @@ struct Function {
     params: Vec<Param>,
     ret_ty: Option<hir::Type>,
     body: FunctionBody,
-    vars_in_body_used_afterwards: Vec<Local>,
+    vars_defined_in_body_and_outlive: Vec<Local>,
 }
 
 impl Function {
@@ -242,8 +246,60 @@ fn has_unit_ret(&self) -> bool {
 
 #[derive(Debug)]
 struct Param {
-    name: String,
-    ty: String,
+    node: Local,
+    has_usages_afterwards: bool,
+    has_mut_inside_body: bool,
+    is_copy: bool,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum ParamKind {
+    Value,
+    MutValue,
+    SharedRef,
+    MutRef,
+}
+
+impl ParamKind {
+    fn is_ref(&self) -> bool {
+        matches!(self, ParamKind::SharedRef | ParamKind::MutRef)
+    }
+}
+
+impl Param {
+    fn kind(&self) -> ParamKind {
+        match (self.has_usages_afterwards, self.has_mut_inside_body, self.is_copy) {
+            (true, true, _) => ParamKind::MutRef,
+            (true, false, false) => ParamKind::SharedRef,
+            (false, true, _) => ParamKind::MutValue,
+            (true, false, true) | (false, false, _) => ParamKind::Value,
+        }
+    }
+
+    fn value_prefix(&self) -> &'static str {
+        match self.kind() {
+            ParamKind::Value => "",
+            ParamKind::MutValue => "",
+            ParamKind::SharedRef => "&",
+            ParamKind::MutRef => "&mut ",
+        }
+    }
+
+    fn type_prefix(&self) -> &'static str {
+        match self.kind() {
+            ParamKind::Value => "",
+            ParamKind::MutValue => "",
+            ParamKind::SharedRef => "&",
+            ParamKind::MutRef => "&mut ",
+        }
+    }
+
+    fn mut_pattern(&self) -> &'static str {
+        match self.kind() {
+            ParamKind::MutValue => "mut ",
+            _ => "",
+        }
+    }
 }
 
 fn format_function(
@@ -259,10 +315,24 @@ fn format_function(
         if let Some(self_param) = &fun.self_param {
             format_to!(fn_def, "{}", self_param);
         } else if let Some(param) = it.next() {
-            format_to!(fn_def, "{}: {}", param.name, param.ty);
+            format_to!(
+                fn_def,
+                "{}{}: {}{}",
+                param.mut_pattern(),
+                param.node.name(ctx.db()).unwrap(),
+                param.type_prefix(),
+                format_type(&param.node.ty(ctx.db()), ctx, module)
+            );
         }
         for param in it {
-            format_to!(fn_def, ", {}: {}", param.name, param.ty);
+            format_to!(
+                fn_def,
+                ", {}{}: {}{}",
+                param.mut_pattern(),
+                param.node.name(ctx.db()).unwrap(),
+                param.type_prefix(),
+                format_type(&param.node.ty(ctx.db()), ctx, module)
+            );
         }
     }
 
@@ -272,7 +342,7 @@ fn format_function(
             format_to!(fn_def, " -> {}", format_type(ty, ctx, module));
         }
     } else {
-        match fun.vars_in_body_used_afterwards.as_slice() {
+        match fun.vars_defined_in_body_and_outlive.as_slice() {
             [] => {}
             [var] => {
                 format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module));
@@ -292,13 +362,21 @@ fn format_function(
         FunctionBody::Expr(expr) => {
             fn_def.push('\n');
             let expr = expr.indent(indent);
-            format_to!(fn_def, "{}{}", indent + 1, expr.syntax());
+            let expr = fix_param_usages(ctx, &fun.params, expr.syntax());
+            format_to!(fn_def, "{}{}", indent + 1, expr);
             fn_def.push('\n');
         }
         FunctionBody::Span { elements, leading_indent } => {
             format_to!(fn_def, "{}", leading_indent);
-            for e in elements {
-                format_to!(fn_def, "{}", e);
+            for element in elements {
+                match element {
+                    syntax::NodeOrToken::Node(node) => {
+                        format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node));
+                    }
+                    syntax::NodeOrToken::Token(token) => {
+                        format_to!(fn_def, "{}", token);
+                    }
+                }
             }
             if !fn_def.ends_with('\n') {
                 fn_def.push('\n');
@@ -306,7 +384,7 @@ fn format_function(
         }
     }
 
-    match fun.vars_in_body_used_afterwards.as_slice() {
+    match fun.vars_defined_in_body_and_outlive.as_slice() {
         [] => {}
         [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()),
         [v0, vs @ ..] => {
@@ -327,6 +405,61 @@ fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> Stri
     ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string())
 }
 
+fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode {
+    let mut rewriter = SyntaxRewriter::default();
+    for param in params {
+        if !param.kind().is_ref() {
+            continue;
+        }
+
+        let usages = Definition::Local(param.node)
+            .usages(&ctx.sema)
+            .in_scope(SearchScope::single_file(ctx.frange.file_id))
+            .all();
+        let usages = usages
+            .iter()
+            .flat_map(|(_, rs)| rs.iter())
+            .filter(|reference| syntax.text_range().contains_range(reference.range));
+        for reference in usages {
+            let token = match syntax.token_at_offset(reference.range.start()).right_biased() {
+                Some(a) => a,
+                None => {
+                    stdx::never!(false, "cannot find token at variable usage: {:?}", reference);
+                    continue;
+                }
+            };
+            let path = match token.ancestors().find_map(ast::Expr::cast) {
+                Some(n) => n,
+                None => {
+                    stdx::never!(false, "cannot find path parent of variable usage: {:?}", token);
+                    continue;
+                }
+            };
+            stdx::always!(matches!(path, ast::Expr::PathExpr(_)));
+            match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) {
+                Some(ast::Expr::MethodCallExpr(_)) => {
+                    // do nothing
+                }
+                Some(ast::Expr::RefExpr(node))
+                    if param.kind() == ParamKind::MutRef && node.mut_token().is_some() =>
+                {
+                    rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
+                }
+                Some(ast::Expr::RefExpr(node))
+                    if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() =>
+                {
+                    rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
+                }
+                Some(_) | None => {
+                    rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone()));
+                }
+            };
+        }
+    }
+
+    rewriter.rewrite(syntax)
+}
+
 #[derive(Debug)]
 enum FunctionBody {
     Expr(ast::Expr),
@@ -1112,6 +1245,164 @@ fn $0fun_name(n: i32) -> (i32, i32) {
     let k = n * n;
     let m = k + 2;
     (k, m)
+}",
+        );
+    }
+
+    #[test]
+    fn mut_var_from_outer_scope() {
+        check_assist(
+            extract_function,
+            r"
+fn foo() {
+    let mut n = 1;
+    $0n += 1;$0
+    let m = n + 1;
+}",
+            r"
+fn foo() {
+    let mut n = 1;
+    fun_name(&mut n);
+    let m = n + 1;
+}
+
+fn $0fun_name(n: &mut i32) {
+    *n += 1;
+}",
+        );
+    }
+
+    #[test]
+    fn mut_param_many_usages_stmt() {
+        check_assist(
+            extract_function,
+            r"
+fn bar(k: i32) {}
+trait I: Copy {
+    fn succ(&self) -> Self;
+    fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
+}
+impl I for i32 {
+    fn succ(&self) -> Self { *self + 1 }
+}
+fn foo() {
+    let mut n = 1;
+    $0n += n;
+    bar(n);
+    bar(n+1);
+    bar(n*n);
+    bar(&n);
+    n.inc();
+    let v = &mut n;
+    *v = v.succ();
+    n.succ();$0
+    let m = n + 1;
+}",
+            r"
+fn bar(k: i32) {}
+trait I: Copy {
+    fn succ(&self) -> Self;
+    fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
+}
+impl I for i32 {
+    fn succ(&self) -> Self { *self + 1 }
+}
+fn foo() {
+    let mut n = 1;
+    fun_name(&mut n);
+    let m = n + 1;
+}
+
+fn $0fun_name(n: &mut i32) {
+    *n += *n;
+    bar(*n);
+    bar(*n+1);
+    bar(*n**n);
+    bar(&*n);
+    n.inc();
+    let v = n;
+    *v = v.succ();
+    n.succ();
+}",
+        );
+    }
+
+    #[test]
+    fn mut_param_many_usages_expr() {
+        check_assist(
+            extract_function,
+            r"
+fn bar(k: i32) {}
+trait I: Copy {
+    fn succ(&self) -> Self;
+    fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
+}
+impl I for i32 {
+    fn succ(&self) -> Self { *self + 1 }
+}
+fn foo() {
+    let mut n = 1;
+    $0{
+        n += n;
+        bar(n);
+        bar(n+1);
+        bar(n*n);
+        bar(&n);
+        n.inc();
+        let v = &mut n;
+        *v = v.succ();
+        n.succ();
+    }$0
+    let m = n + 1;
+}",
+            r"
+fn bar(k: i32) {}
+trait I: Copy {
+    fn succ(&self) -> Self;
+    fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
+}
+impl I for i32 {
+    fn succ(&self) -> Self { *self + 1 }
+}
+fn foo() {
+    let mut n = 1;
+    fun_name(&mut n);
+    let m = n + 1;
+}
+
+fn $0fun_name(n: &mut i32) {
+    {
+        *n += *n;
+        bar(*n);
+        bar(*n+1);
+        bar(*n**n);
+        bar(&*n);
+        n.inc();
+        let v = n;
+        *v = v.succ();
+        n.succ();
+    }
+}",
+        );
+    }
+
+    #[test]
+    fn mut_param_by_value() {
+        check_assist(
+            extract_function,
+            r"
+fn foo() {
+    let mut n = 1;
+    $0n += 1;$0
+}",
+            r"
+fn foo() {
+    let mut n = 1;
+    fun_name(n);
+}
+
+fn $0fun_name(mut n: i32) {
+    n += 1;
 }",
         );
     }
index b755c969288043541129b0a66ba87cdc0058f5da..1da5a125ed376cdbb4df813fff3176b70f70a7b0 100644 (file)
@@ -487,7 +487,7 @@ pub mod tokens {
     use crate::{ast, AstNode, Parse, SourceFile, SyntaxKind::*, SyntaxToken};
 
     pub(super) static SOURCE_FILE: Lazy<Parse<SourceFile>> =
-        Lazy::new(|| SourceFile::parse("const C: <()>::Item = (1 != 1, 2 == 2, !true)\n;\n\n"));
+        Lazy::new(|| SourceFile::parse("const C: <()>::Item = (1 != 1, 2 == 2, !true, *p)\n;\n\n"));
 
     pub fn single_space() -> SyntaxToken {
         SOURCE_FILE