]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/extract_function.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / extract_function.rs
index 7ffb5728cc5f2f23942bd73314a3dd81d288fdb7..21cfc76ac9bf6036d100c489218ca0deb6c171ca 100644 (file)
@@ -1,4 +1,4 @@
-use std::{hash::BuildHasherDefault, iter};
+use std::iter;
 
 use ast::make;
 use either::Either;
         FamousDefs,
     },
     search::{FileReference, ReferenceCategory, SearchScope},
-    RootDatabase,
+    FxIndexSet, RootDatabase,
 };
 use itertools::Itertools;
-use rustc_hash::FxHasher;
 use stdx::format_to;
 use syntax::{
     ast::{
     AssistId,
 };
 
-type FxIndexSet<T> = indexmap::IndexSet<T, BuildHasherDefault<FxHasher>>;
-
 // Assist: extract_function
 //
-// Extracts selected statements into new function.
+// Extracts selected statements and comments into new function.
 //
 // ```
 // fn main() {
 //     let n = 1;
 //     $0let m = n + 2;
+//     // calculate
 //     let k = m + n;$0
 //     let g = 3;
 // }
@@ -57,6 +55,7 @@
 //
 // fn $0fun_name(n: i32) {
 //     let m = n + 2;
+//     // calculate
 //     let k = m + n;
 // }
 // ```
@@ -76,6 +75,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
         syntax::NodeOrToken::Node(n) => n,
         syntax::NodeOrToken::Token(t) => t.parent()?,
     };
+
     let body = extraction_target(&node, range)?;
     let container_info = body.analyze_container(&ctx.sema)?;
 
@@ -91,7 +91,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
 
     let target_range = body.text_range();
 
-    let scope = ImportScope::find_insert_use_container_with_macros(&node, &ctx.sema)?;
+    let scope = ImportScope::find_insert_use_container(&node, &ctx.sema)?;
 
     acc.add(
         AssistId("extract_function", crate::AssistKind::RefactorExtract),
@@ -479,11 +479,14 @@ fn from_expr(expr: ast::Expr) -> Option<Self> {
     }
 
     fn from_range(parent: ast::StmtList, selected: TextRange) -> FunctionBody {
-        let mut text_range = parent
-            .statements()
-            .map(|stmt| stmt.syntax().text_range())
-            .filter(|&stmt| selected.intersect(stmt).filter(|it| !it.is_empty()).is_some())
-            .fold1(|acc, stmt| acc.cover(stmt));
+        let full_body = parent.syntax().children_with_tokens();
+
+        let mut text_range = full_body
+            .filter(|it| ast::Stmt::can_cast(it.kind()) || it.kind() == COMMENT)
+            .map(|element| element.text_range())
+            .filter(|&range| selected.intersect(range).filter(|it| !it.is_empty()).is_some())
+            .reduce(|acc, stmt| acc.cover(stmt));
+
         if let Some(tail_range) = parent
             .tail_expr()
             .map(|it| it.syntax().text_range())
@@ -881,7 +884,7 @@ fn extracted_function_params(
                 // We can move the value into the function call if it's not used after the call,
                 // if the var is not used but defined outside a loop we are extracting from we can't move it either
                 // as the function will reuse it in the next iteration.
-                let move_local = !has_usages && defined_outside_parent_loop;
+                let move_local = (!has_usages && defined_outside_parent_loop) || ty.is_reference();
                 Param { var, ty, move_local, requires_mut, is_copy }
             })
             .collect()
@@ -1216,28 +1219,26 @@ fn make_call_expr(&self, call_expr: ast::Expr) -> ast::Expr {
                 let stmt = make::expr_stmt(action);
                 let block = make::block_expr(iter::once(stmt.into()), None);
                 let controlflow_break_path = make::path_from_text("ControlFlow::Break");
-                let condition = make::condition(
+                let condition = make::expr_let(
+                    make::tuple_struct_pat(
+                        controlflow_break_path,
+                        iter::once(make::wildcard_pat().into()),
+                    )
+                    .into(),
                     call_expr,
-                    Some(
-                        make::tuple_struct_pat(
-                            controlflow_break_path,
-                            iter::once(make::wildcard_pat().into()),
-                        )
-                        .into(),
-                    ),
                 );
-                make::expr_if(condition, block, None)
+                make::expr_if(condition.into(), block, None)
             }
             FlowHandler::IfOption { action } => {
                 let path = make::ext::ident_path("Some");
                 let value_pat = make::ext::simple_ident_pat(make::name("value"));
                 let pattern = make::tuple_struct_pat(path, iter::once(value_pat.into()));
-                let cond = make::condition(call_expr, Some(pattern.into()));
+                let cond = make::expr_let(pattern.into(), call_expr);
                 let value = make::expr_path(make::ext::ident_path("value"));
                 let action_expr = action.make_result_handler(Some(value));
                 let action_stmt = make::expr_stmt(action_expr);
                 let then = make::block_expr(iter::once(action_stmt.into()), None);
-                make::expr_if(cond, then, None)
+                make::expr_if(cond.into(), then, None)
             }
             FlowHandler::MatchOption { none } => {
                 let some_name = "value";
@@ -1423,6 +1424,7 @@ fn make_body(
     } else {
         FlowHandler::from_ret_ty(fun, &ret_ty)
     };
+
     let block = match &fun.body {
         FunctionBody::Expr(expr) => {
             let expr = rewrite_body_segment(ctx, &fun.params, &handler, expr.syntax());
@@ -1444,21 +1446,28 @@ fn make_body(
         FunctionBody::Span { parent, text_range } => {
             let mut elements: Vec<_> = parent
                 .syntax()
-                .children()
+                .children_with_tokens()
                 .filter(|it| text_range.contains_range(it.text_range()))
-                .map(|it| rewrite_body_segment(ctx, &fun.params, &handler, &it))
+                .map(|it| match &it {
+                    syntax::NodeOrToken::Node(n) => syntax::NodeOrToken::Node(
+                        rewrite_body_segment(ctx, &fun.params, &handler, &n),
+                    ),
+                    _ => it,
+                })
                 .collect();
 
-            let mut tail_expr = match elements.pop() {
-                Some(node) => ast::Expr::cast(node.clone()).or_else(|| {
-                    elements.push(node);
-                    None
-                }),
-                None => None,
+            let mut tail_expr = match &elements.last() {
+                Some(syntax::NodeOrToken::Node(node)) if ast::Expr::can_cast(node.kind()) => {
+                    ast::Expr::cast(node.clone())
+                }
+                _ => None,
             };
 
-            if tail_expr.is_none() {
-                match fun.outliving_locals.as_slice() {
+            match tail_expr {
+                Some(_) => {
+                    elements.pop();
+                }
+                None => match fun.outliving_locals.as_slice() {
                     [] => {}
                     [var] => {
                         tail_expr = Some(path_expr_from_local(ctx, var.local));
@@ -1468,22 +1477,27 @@ fn make_body(
                         let expr = make::expr_tuple(exprs);
                         tail_expr = Some(expr);
                     }
-                }
-            }
-
-            let elements = elements.into_iter().filter_map(|node| match ast::Stmt::cast(node) {
-                Some(stmt) => Some(stmt),
-                None => {
-                    stdx::never!("block contains non-statement");
-                    None
-                }
-            });
+                },
+            };
 
             let body_indent = IndentLevel(1);
-            let elements = elements.map(|stmt| stmt.dedent(old_indent).indent(body_indent));
+            let elements = elements
+                .into_iter()
+                .map(|node_or_token| match &node_or_token {
+                    syntax::NodeOrToken::Node(node) => match ast::Stmt::cast(node.clone()) {
+                        Some(stmt) => {
+                            let indented = stmt.dedent(old_indent).indent(body_indent);
+                            let ast_node = indented.syntax().clone_subtree();
+                            syntax::NodeOrToken::Node(ast_node)
+                        }
+                        _ => node_or_token,
+                    },
+                    _ => node_or_token,
+                })
+                .collect::<Vec<SyntaxElement>>();
             let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent));
 
-            make::block_expr(elements, tail_expr)
+            make::hacky_block_expr_with_comments(elements, tail_expr)
         }
     };
 
@@ -4095,13 +4109,34 @@ fn foo() {
 "#,
             r#"
 fn foo() {
-    /**/
     fun_name();
-    /**/
 }
 
 fn $0fun_name() {
+    /**/
+    foo();
     foo();
+    /**/
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_does_not_tear_body_apart() {
+        check_assist(
+            extract_function,
+            r#"
+fn foo() {
+    $0foo();
+}$0
+"#,
+            r#"
+fn foo() {
+    fun_name();
+}
+
+fn $0fun_name() {
     foo();
 }
 "#,
@@ -4335,6 +4370,226 @@ fn foo() {
 fn $0fun_name(a: _) -> _ {
     a
 }
+"#,
+        );
+    }
+
+    #[test]
+    fn reference_mutable_param_with_further_usages() {
+        check_assist(
+            extract_function,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    $0arg.field = 8;$0
+    // Simulating access after the extracted portion
+    arg.field = 16;
+}
+"#,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    fun_name(arg);
+    // Simulating access after the extracted portion
+    arg.field = 16;
+}
+
+fn $0fun_name(arg: &mut Foo) {
+    arg.field = 8;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn reference_mutable_param_without_further_usages() {
+        check_assist(
+            extract_function,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    $0arg.field = 8;$0
+}
+"#,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    fun_name(arg);
+}
+
+fn $0fun_name(arg: &mut Foo) {
+    arg.field = 8;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_at_start() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0// comment here!
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    // comment here!
+    let x = 0;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_in_between() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;$0
+    let a = 0;
+    // comment here!
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let a = 0;
+    // comment here!
+    let x = 0;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_at_end() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0let x = 0;
+    // comment here!$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let x = 0;
+    // comment here!
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_indented() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0let x = 0;
+    while(true) {
+        // comment here!
+    }$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let x = 0;
+    while(true) {
+        // comment here!
+    }
+}
+"#,
+        );
+    }
+
+    // FIXME: we do want to preserve whitespace
+    #[test]
+    fn extract_function_does_not_preserve_whitespace() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0let a = 0;
+
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let a = 0;
+    let x = 0;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_long_form_comment() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0/* a comment */
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    /* a comment */
+    let x = 0;
+}
 "#,
         );
     }