]> git.lizzy.rs Git - rust.git/commitdiff
allow local variables to be used after extracted body
authorVladyslav Katasonov <cpud47@gmail.com>
Wed, 3 Feb 2021 17:31:12 +0000 (20:31 +0300)
committerVladyslav Katasonov <cpud47@gmail.com>
Wed, 3 Feb 2021 18:11:12 +0000 (21:11 +0300)
when variable is defined inside extracted body
export this variable to original scope via return value(s)

crates/assists/src/handlers/extract_function.rs

index 958199e5e1656d89944f7cb5cceb25f81bcdefbe..c5e6ec7331bd95e48c517fa69dab272fa8a8e7c1 100644 (file)
@@ -1,7 +1,10 @@
 use either::Either;
 use hir::{HirDisplay, Local};
-use ide_db::defs::{Definition, NameRefClass};
-use rustc_hash::FxHashSet;
+use ide_db::{
+    defs::{Definition, NameRefClass},
+    search::SearchScope,
+};
+use itertools::Itertools;
 use stdx::format_to;
 use syntax::{
     ast::{
@@ -81,9 +84,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
     }
     let body = body?;
 
+    let vars_used_in_body = vars_used_in_body(&body, &ctx);
     let mut self_param = None;
-    let mut param_pats: Vec<_> = local_variables(&body, &ctx)
-        .into_iter()
+    let param_pats: Vec<_> = vars_used_in_body
+        .iter()
         .map(|node| node.source(ctx.db()))
         .filter(|src| {
             src.file_id.original_file(ctx.db()) == ctx.frange.file_id
@@ -98,12 +102,27 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
             }
         })
         .collect();
-    deduplicate_params(&mut param_pats);
 
     let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding };
     let insert_after = body.scope_for_fn_insertion(anchor)?;
     let module = ctx.sema.scope(&insert_after).module()?;
 
+    let vars_defined_in_body = vars_defined_in_body(&body, ctx);
+
+    let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body
+        .iter()
+        .copied()
+        .filter(|node| {
+            let usages = Definition::Local(*node)
+                .usages(&ctx.sema)
+                .in_scope(SearchScope::single_file(ctx.frange.file_id))
+                .all();
+            let mut usages = usages.iter().flat_map(|(_, rs)| rs.iter());
+
+            usages.any(|reference| body.preceedes_range(reference.range))
+        })
+        .collect();
+
     let params = param_pats
         .into_iter()
         .map(|pat| {
@@ -119,20 +138,18 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
         })
         .collect::<Vec<_>>();
 
-    let self_param =
-        if let Some(self_param) = self_param { Some(self_param.to_string()) } else { None };
-
     let expr = body.tail_expr();
     let ret_ty = match expr {
-        Some(expr) => {
-            // FIXME: can we do assist when type is unknown?
-            //        We can insert something like `-> ()`
-            let ty = ctx.sema.type_of_expr(&expr)?;
-            Some(ty.display_source_code(ctx.db(), module.into()).ok()?)
-        }
+        Some(expr) => Some(ctx.sema.type_of_expr(&expr)?),
         None => None,
     };
 
+    let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit());
+    if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) {
+        // We should not have variables that outlive body if we have expression block
+        return None;
+    }
+
     let target_range = match &body {
         FunctionBody::Expr(expr) => expr.syntax().text_range(),
         FunctionBody::Span { .. } => ctx.frange.range,
@@ -143,21 +160,46 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
         "Extract into function",
         target_range,
         move |builder| {
-            let fun = Function { name: "fun_name".to_string(), self_param, params, ret_ty, body };
+            let fun = Function {
+                name: "fun_name".to_string(),
+                self_param,
+                params,
+                ret_ty,
+                body,
+                vars_in_body_used_afterwards,
+            };
 
-            builder.replace(target_range, format_replacement(&fun));
+            builder.replace(target_range, format_replacement(ctx, &fun));
 
             let indent = IndentLevel::from_node(&insert_after);
 
-            let fn_def = format_function(&fun, indent);
+            let fn_def = format_function(ctx, module, &fun, indent);
             let insert_offset = insert_after.text_range().end();
             builder.insert(insert_offset, fn_def);
         },
     )
 }
 
-fn format_replacement(fun: &Function) -> String {
+fn format_replacement(ctx: &AssistContext, fun: &Function) -> String {
     let mut buf = String::new();
+
+    match fun.vars_in_body_used_afterwards.len() {
+        0 => {}
+        1 => format_to!(
+            buf,
+            "let {} = ",
+            fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()
+        ),
+        _ => {
+            buf.push_str("let (");
+            format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap());
+            for local in fun.vars_in_body_used_afterwards.iter().skip(1) {
+                format_to!(buf, ", {}", local.name(ctx.db()).unwrap());
+            }
+            buf.push_str(") = ");
+        }
+    }
+
     if fun.self_param.is_some() {
         format_to!(buf, "self.");
     }
@@ -182,16 +224,17 @@ fn format_replacement(fun: &Function) -> String {
 
 struct Function {
     name: String,
-    self_param: Option<String>,
+    self_param: Option<ast::SelfParam>,
     params: Vec<Param>,
-    ret_ty: Option<String>,
+    ret_ty: Option<hir::Type>,
     body: FunctionBody,
+    vars_in_body_used_afterwards: Vec<Local>,
 }
 
 impl Function {
     fn has_unit_ret(&self) -> bool {
         match &self.ret_ty {
-            Some(ty) => ty == "()",
+            Some(ty) => ty.is_unit(),
             None => true,
         }
     }
@@ -203,7 +246,12 @@ struct Param {
     ty: String,
 }
 
-fn format_function(fun: &Function, indent: IndentLevel) -> String {
+fn format_function(
+    ctx: &AssistContext,
+    module: hir::Module,
+    fun: &Function,
+    indent: IndentLevel,
+) -> String {
     let mut fn_def = String::new();
     format_to!(fn_def, "\n\n{}fn $0{}(", indent, fun.name);
     {
@@ -221,10 +269,24 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String {
     format_to!(fn_def, ")");
     if !fun.has_unit_ret() {
         if let Some(ty) = &fun.ret_ty {
-            format_to!(fn_def, " -> {}", ty);
+            format_to!(fn_def, " -> {}", format_type(ty, ctx, module));
+        }
+    } else {
+        match fun.vars_in_body_used_afterwards.as_slice() {
+            [] => {}
+            [var] => {
+                format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module));
+            }
+            [v0, vs @ ..] => {
+                format_to!(fn_def, " -> ({}", format_type(&v0.ty(ctx.db()), ctx, module));
+                for var in vs {
+                    format_to!(fn_def, ", {}", format_type(&var.ty(ctx.db()), ctx, module));
+                }
+                fn_def.push(')');
+            }
         }
     }
-    format_to!(fn_def, " {{");
+    fn_def.push_str(" {");
 
     match &fun.body {
         FunctionBody::Expr(expr) => {
@@ -243,11 +305,28 @@ fn format_function(fun: &Function, indent: IndentLevel) -> String {
             }
         }
     }
+
+    match fun.vars_in_body_used_afterwards.as_slice() {
+        [] => {}
+        [var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()),
+        [v0, vs @ ..] => {
+            format_to!(fn_def, "{}({}", indent + 1, v0.name(ctx.db()).unwrap());
+            for var in vs {
+                format_to!(fn_def, ", {}", var.name(ctx.db()).unwrap());
+            }
+            fn_def.push_str(")\n");
+        }
+    }
+
     format_to!(fn_def, "{}}}", indent);
 
     fn_def
 }
 
+fn format_type(ty: &hir::Type, ctx: &AssistContext, module: hir::Module) -> String {
+    ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string())
+}
+
 #[derive(Debug)]
 enum FunctionBody {
     Expr(ast::Expr),
@@ -339,18 +418,26 @@ fn descendants(&self) -> impl Iterator<Item = SyntaxNode> + '_ {
         }
     }
 
-    fn contains_node(&self, node: &SyntaxNode) -> bool {
-        fn is_node(body: &FunctionBody, n: &SyntaxNode) -> bool {
-            match body {
-                FunctionBody::Expr(expr) => n == expr.syntax(),
-                FunctionBody::Span { elements, .. } => {
-                    // FIXME: can it be quadratic?
-                    elements.iter().filter_map(SyntaxElement::as_node).any(|e| e == n)
-                }
-            }
+    fn text_range(&self) -> TextRange {
+        match self {
+            FunctionBody::Expr(expr) => expr.syntax().text_range(),
+            FunctionBody::Span { elements, .. } => TextRange::new(
+                elements.first().unwrap().text_range().start(),
+                elements.last().unwrap().text_range().end(),
+            ),
         }
+    }
+
+    fn contains_range(&self, range: TextRange) -> bool {
+        self.text_range().contains_range(range)
+    }
 
-        node.ancestors().any(|a| is_node(self, &a))
+    fn preceedes_range(&self, range: TextRange) -> bool {
+        self.text_range().end() <= range.start()
+    }
+
+    fn contains_node(&self, node: &SyntaxNode) -> bool {
+        self.contains_range(node.text_range())
     }
 }
 
@@ -383,11 +470,6 @@ fn scope_for_fn_insertion(node: &SyntaxNode, anchor: Anchor) -> Option<SyntaxNod
     last_ancestor
 }
 
-fn deduplicate_params(params: &mut Vec<ast::IdentPat>) {
-    let mut seen_params = FxHashSet::default();
-    params.retain(|p| seen_params.insert(p.clone()));
-}
-
 fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
     match value {
         Either::Left(pat) => pat.syntax(),
@@ -395,8 +477,8 @@ fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
     }
 }
 
-/// Returns a vector of local variables that are refferenced in `body`
-fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
+/// Returns a vector of local variables that are referenced in `body`
+fn vars_used_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
     body.descendants()
         .filter_map(ast::NameRef::cast)
         .filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref))
@@ -405,6 +487,16 @@ fn local_variables(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
             Definition::Local(local) => Some(local),
             _ => None,
         })
+        .unique()
+        .collect()
+}
+
+/// Returns a vector of local variables that are defined in `body`
+fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
+    body.descendants()
+        .filter_map(ast::IdentPat::cast)
+        .filter_map(|let_stmt| ctx.sema.to_def(&let_stmt))
+        .unique()
         .collect()
 }
 
@@ -970,6 +1062,56 @@ fn foo(&mut self) -> i32 {
     fn $0fun_name(&self) -> i32 {
         1+self.f
     }
+}",
+        );
+    }
+
+    #[test]
+    fn variable_defined_inside_and_used_after_no_ret() {
+        check_assist(
+            extract_function,
+            r"
+fn foo() {
+    let n = 1;
+    $0let k = n * n;$0
+    let m = k + 1;
+}",
+            r"
+fn foo() {
+    let n = 1;
+    let k = fun_name(n);
+    let m = k + 1;
+}
+
+fn $0fun_name(n: i32) -> i32 {
+    let k = n * n;
+    k
+}",
+        );
+    }
+
+    #[test]
+    fn two_variables_defined_inside_and_used_after_no_ret() {
+        check_assist(
+            extract_function,
+            r"
+fn foo() {
+    let n = 1;
+    $0let k = n * n;
+    let m = k + 2;$0
+    let h = k + m;
+}",
+            r"
+fn foo() {
+    let n = 1;
+    let (k, m) = fun_name(n);
+    let h = k + m;
+}
+
+fn $0fun_name(n: i32) -> (i32, i32) {
+    let k = n * n;
+    let m = k + 2;
+    (k, m)
 }",
         );
     }