]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/extract_function.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / extract_function.rs
index f33b455f141aec1c9b0ab0be928257a8918ab217..21cfc76ac9bf6036d100c489218ca0deb6c171ca 100644 (file)
@@ -1,16 +1,20 @@
-use std::{hash::BuildHasherDefault, iter};
+use std::iter;
 
 use ast::make;
 use either::Either;
-use hir::{HirDisplay, InFile, Local, Semantics, TypeInfo};
+use hir::{HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo};
 use ide_db::{
     defs::{Definition, NameRefClass},
-    helpers::node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr},
+    helpers::{
+        insert_use::{insert_use, ImportScope},
+        mod_path_to_ast,
+        node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr},
+        FamousDefs,
+    },
     search::{FileReference, ReferenceCategory, SearchScope},
-    RootDatabase,
+    FxIndexSet, RootDatabase,
 };
 use itertools::Itertools;
-use rustc_hash::FxHasher;
 use stdx::format_to;
 use syntax::{
     ast::{
     AssistId,
 };
 
-type FxIndexSet<T> = indexmap::IndexSet<T, BuildHasherDefault<FxHasher>>;
-
 // Assist: extract_function
 //
-// Extracts selected statements into new function.
+// Extracts selected statements and comments into new function.
 //
 // ```
 // fn main() {
 //     let n = 1;
 //     $0let m = n + 2;
+//     // calculate
 //     let k = m + n;$0
 //     let g = 3;
 // }
@@ -52,6 +55,7 @@
 //
 // fn $0fun_name(n: i32) {
 //     let m = n + 2;
+//     // calculate
 //     let k = m + n;
 // }
 // ```
@@ -71,6 +75,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
         syntax::NodeOrToken::Node(n) => n,
         syntax::NodeOrToken::Token(t) => t.parent()?,
     };
+
     let body = extraction_target(&node, range)?;
     let container_info = body.analyze_container(&ctx.sema)?;
 
@@ -86,6 +91,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
 
     let target_range = body.text_range();
 
+    let scope = ImportScope::find_insert_use_container(&node, &ctx.sema)?;
+
     acc.add(
         AssistId("extract_function", crate::AssistKind::RefactorExtract),
         "Extract into function",
@@ -118,10 +125,34 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
 
             let fn_def = format_function(ctx, module, &fun, old_indent, new_indent);
             let insert_offset = insert_after.text_range().end();
+
+            if fn_def.contains("ControlFlow") {
+                let scope = match scope {
+                    ImportScope::File(it) => ImportScope::File(builder.make_mut(it)),
+                    ImportScope::Module(it) => ImportScope::Module(builder.make_mut(it)),
+                    ImportScope::Block(it) => ImportScope::Block(builder.make_mut(it)),
+                };
+
+                let control_flow_enum =
+                    FamousDefs(&ctx.sema, Some(module.krate())).core_ops_ControlFlow();
+
+                if let Some(control_flow_enum) = control_flow_enum {
+                    let mod_path = module.find_use_path_prefixed(
+                        ctx.sema.db,
+                        ModuleDef::from(control_flow_enum),
+                        ctx.config.insert_use.prefix_kind,
+                    );
+
+                    if let Some(mod_path) = mod_path {
+                        insert_use(&scope, mod_path_to_ast(&mod_path), &ctx.config.insert_use);
+                    }
+                }
+            }
+
             match ctx.config.snippet_cap {
                 Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def),
                 None => builder.insert(insert_offset, fn_def),
-            }
+            };
         },
     )
 }
@@ -309,7 +340,7 @@ fn find_local_usages(ctx: &AssistContext, var: Local) -> Self {
         Self(
             Definition::Local(var)
                 .usages(&ctx.sema)
-                .in_scope(SearchScope::single_file(ctx.frange.file_id))
+                .in_scope(SearchScope::single_file(ctx.file_id()))
                 .all(),
         )
     }
@@ -448,11 +479,14 @@ fn from_expr(expr: ast::Expr) -> Option<Self> {
     }
 
     fn from_range(parent: ast::StmtList, selected: TextRange) -> FunctionBody {
-        let mut text_range = parent
-            .statements()
-            .map(|stmt| stmt.syntax().text_range())
-            .filter(|&stmt| selected.intersect(stmt).filter(|it| !it.is_empty()).is_some())
-            .fold1(|acc, stmt| acc.cover(stmt));
+        let full_body = parent.syntax().children_with_tokens();
+
+        let mut text_range = full_body
+            .filter(|it| ast::Stmt::can_cast(it.kind()) || it.kind() == COMMENT)
+            .map(|element| element.text_range())
+            .filter(|&range| selected.intersect(range).filter(|it| !it.is_empty()).is_some())
+            .reduce(|acc, stmt| acc.cover(stmt));
+
         if let Some(tail_range) = parent
             .tail_expr()
             .map(|it| it.syntax().text_range())
@@ -623,7 +657,7 @@ fn analyze(
                         .children_with_tokens()
                         .flat_map(SyntaxElement::into_token)
                         .filter(|it| it.kind() == SyntaxKind::IDENT)
-                        .flat_map(|t| sema.descend_into_macros_many(t))
+                        .flat_map(|t| sema.descend_into_macros(t))
                         .for_each(|t| cb(t.parent().and_then(ast::NameRef::cast)));
                 }
             }
@@ -850,7 +884,7 @@ fn extracted_function_params(
                 // We can move the value into the function call if it's not used after the call,
                 // if the var is not used but defined outside a loop we are extracting from we can't move it either
                 // as the function will reuse it in the next iteration.
-                let move_local = !has_usages && defined_outside_parent_loop;
+                let move_local = (!has_usages && defined_outside_parent_loop) || ty.is_reference();
                 Param { var, ty, move_local, requires_mut, is_copy }
             })
             .collect()
@@ -1039,7 +1073,7 @@ fn is_defined_outside_of_body(
     body: &FunctionBody,
     src: &hir::InFile<Either<ast::IdentPat, ast::SelfParam>>,
 ) -> bool {
-    src.file_id.original_file(ctx.db()) == ctx.frange.file_id
+    src.file_id.original_file(ctx.db()) == ctx.file_id()
         && !body.contains_node(either_syntax(&src.value))
 }
 
@@ -1184,19 +1218,27 @@ fn make_call_expr(&self, call_expr: ast::Expr) -> ast::Expr {
                 let action = action.make_result_handler(None);
                 let stmt = make::expr_stmt(action);
                 let block = make::block_expr(iter::once(stmt.into()), None);
-                let condition = make::condition(call_expr, None);
-                make::expr_if(condition, block, None)
+                let controlflow_break_path = make::path_from_text("ControlFlow::Break");
+                let condition = make::expr_let(
+                    make::tuple_struct_pat(
+                        controlflow_break_path,
+                        iter::once(make::wildcard_pat().into()),
+                    )
+                    .into(),
+                    call_expr,
+                );
+                make::expr_if(condition.into(), block, None)
             }
             FlowHandler::IfOption { action } => {
                 let path = make::ext::ident_path("Some");
                 let value_pat = make::ext::simple_ident_pat(make::name("value"));
                 let pattern = make::tuple_struct_pat(path, iter::once(value_pat.into()));
-                let cond = make::condition(call_expr, Some(pattern.into()));
+                let cond = make::expr_let(pattern.into(), call_expr);
                 let value = make::expr_path(make::ext::ident_path("value"));
                 let action_expr = action.make_result_handler(Some(value));
                 let action_stmt = make::expr_stmt(action_expr);
                 let then = make::block_expr(iter::once(action_stmt.into()), None);
-                make::expr_if(cond, then, None)
+                make::expr_if(cond.into(), then, None)
             }
             FlowHandler::MatchOption { none } => {
                 let some_name = "value";
@@ -1326,7 +1368,7 @@ fn make_ret_ty(&self, ctx: &AssistContext, module: hir::Module) -> Option<ast::R
                     .unwrap_or_else(make::ty_placeholder);
                 make::ext::ty_result(fun_ty.make_ty(ctx, module), handler_ty)
             }
-            FlowHandler::If { .. } => make::ext::ty_bool(),
+            FlowHandler::If { .. } => make::ty("ControlFlow<()>"),
             FlowHandler::IfOption { action } => {
                 let handler_ty = action
                     .expr_ty(ctx)
@@ -1382,6 +1424,7 @@ fn make_body(
     } else {
         FlowHandler::from_ret_ty(fun, &ret_ty)
     };
+
     let block = match &fun.body {
         FunctionBody::Expr(expr) => {
             let expr = rewrite_body_segment(ctx, &fun.params, &handler, expr.syntax());
@@ -1403,21 +1446,28 @@ fn make_body(
         FunctionBody::Span { parent, text_range } => {
             let mut elements: Vec<_> = parent
                 .syntax()
-                .children()
+                .children_with_tokens()
                 .filter(|it| text_range.contains_range(it.text_range()))
-                .map(|it| rewrite_body_segment(ctx, &fun.params, &handler, &it))
+                .map(|it| match &it {
+                    syntax::NodeOrToken::Node(n) => syntax::NodeOrToken::Node(
+                        rewrite_body_segment(ctx, &fun.params, &handler, &n),
+                    ),
+                    _ => it,
+                })
                 .collect();
 
-            let mut tail_expr = match elements.pop() {
-                Some(node) => ast::Expr::cast(node.clone()).or_else(|| {
-                    elements.push(node);
-                    None
-                }),
-                None => None,
+            let mut tail_expr = match &elements.last() {
+                Some(syntax::NodeOrToken::Node(node)) if ast::Expr::can_cast(node.kind()) => {
+                    ast::Expr::cast(node.clone())
+                }
+                _ => None,
             };
 
-            if tail_expr.is_none() {
-                match fun.outliving_locals.as_slice() {
+            match tail_expr {
+                Some(_) => {
+                    elements.pop();
+                }
+                None => match fun.outliving_locals.as_slice() {
                     [] => {}
                     [var] => {
                         tail_expr = Some(path_expr_from_local(ctx, var.local));
@@ -1427,22 +1477,27 @@ fn make_body(
                         let expr = make::expr_tuple(exprs);
                         tail_expr = Some(expr);
                     }
-                }
-            }
-
-            let elements = elements.into_iter().filter_map(|node| match ast::Stmt::cast(node) {
-                Some(stmt) => Some(stmt),
-                None => {
-                    stdx::never!("block contains non-statement");
-                    None
-                }
-            });
+                },
+            };
 
             let body_indent = IndentLevel(1);
-            let elements = elements.map(|stmt| stmt.dedent(old_indent).indent(body_indent));
+            let elements = elements
+                .into_iter()
+                .map(|node_or_token| match &node_or_token {
+                    syntax::NodeOrToken::Node(node) => match ast::Stmt::cast(node.clone()) {
+                        Some(stmt) => {
+                            let indented = stmt.dedent(old_indent).indent(body_indent);
+                            let ast_node = indented.syntax().clone_subtree();
+                            syntax::NodeOrToken::Node(ast_node)
+                        }
+                        _ => node_or_token,
+                    },
+                    _ => node_or_token,
+                })
+                .collect::<Vec<SyntaxElement>>();
             let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent));
 
-            make::block_expr(elements, tail_expr)
+            make::hacky_block_expr_with_comments(elements, tail_expr)
         }
     };
 
@@ -1461,8 +1516,11 @@ fn make_body(
             })
         }
         FlowHandler::If { .. } => {
-            let lit_false = make::expr_literal("false");
-            with_tail_expr(block, lit_false.into())
+            let controlflow_continue = make::expr_call(
+                make::expr_path(make::path_from_text("ControlFlow::Continue")),
+                make::arg_list(iter::once(make::expr_unit())),
+            );
+            with_tail_expr(block, controlflow_continue.into())
         }
         FlowHandler::IfOption { .. } => {
             let none = make::expr_path(make::ext::ident_path("None"));
@@ -1638,7 +1696,10 @@ fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) {
 fn make_rewritten_flow(handler: &FlowHandler, arg_expr: Option<ast::Expr>) -> Option<ast::Expr> {
     let value = match handler {
         FlowHandler::None | FlowHandler::Try { .. } => return None,
-        FlowHandler::If { .. } => make::expr_literal("true").into(),
+        FlowHandler::If { .. } => make::expr_call(
+            make::expr_path(make::path_from_text("ControlFlow::Break")),
+            make::arg_list(iter::once(make::expr_unit())),
+        ),
         FlowHandler::IfOption { .. } => {
             let expr = arg_expr.unwrap_or_else(|| make::expr_tuple(Vec::new()));
             let args = make::arg_list(iter::once(expr));
@@ -3270,6 +3331,7 @@ fn break_loop_with_if() {
         check_assist(
             extract_function,
             r#"
+//- minicore: try
 fn foo() {
     loop {
         let mut n = 1;
@@ -3281,21 +3343,23 @@ fn foo() {
 }
 "#,
             r#"
+use core::ops::ControlFlow;
+
 fn foo() {
     loop {
         let mut n = 1;
-        if fun_name(&mut n) {
+        if let ControlFlow::Break(_) = fun_name(&mut n) {
             break;
         }
         let h = 1 + n;
     }
 }
 
-fn $0fun_name(n: &mut i32) -> bool {
+fn $0fun_name(n: &mut i32) -> ControlFlow<()> {
     let m = *n + 1;
-    return true;
+    return ControlFlow::Break(());
     *n += m;
-    false
+    ControlFlow::Continue(())
 }
 "#,
         );
@@ -3306,6 +3370,7 @@ fn break_loop_nested() {
         check_assist(
             extract_function,
             r#"
+//- minicore: try
 fn foo() {
     loop {
         let mut n = 1;
@@ -3318,22 +3383,24 @@ fn foo() {
 }
 "#,
             r#"
+use core::ops::ControlFlow;
+
 fn foo() {
     loop {
         let mut n = 1;
-        if fun_name(n) {
+        if let ControlFlow::Break(_) = fun_name(n) {
             break;
         }
         let h = 1;
     }
 }
 
-fn $0fun_name(n: i32) -> bool {
+fn $0fun_name(n: i32) -> ControlFlow<()> {
     let m = n + 1;
     if m == 42 {
-        return true;
+        return ControlFlow::Break(());
     }
-    false
+    ControlFlow::Continue(())
 }
 "#,
         );
@@ -4042,13 +4109,34 @@ fn foo() {
 "#,
             r#"
 fn foo() {
-    /**/
     fun_name();
-    /**/
 }
 
 fn $0fun_name() {
+    /**/
+    foo();
     foo();
+    /**/
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_does_not_tear_body_apart() {
+        check_assist(
+            extract_function,
+            r#"
+fn foo() {
+    $0foo();
+}$0
+"#,
+            r#"
+fn foo() {
+    fun_name();
+}
+
+fn $0fun_name() {
     foo();
 }
 "#,
@@ -4282,6 +4370,226 @@ fn foo() {
 fn $0fun_name(a: _) -> _ {
     a
 }
+"#,
+        );
+    }
+
+    #[test]
+    fn reference_mutable_param_with_further_usages() {
+        check_assist(
+            extract_function,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    $0arg.field = 8;$0
+    // Simulating access after the extracted portion
+    arg.field = 16;
+}
+"#,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    fun_name(arg);
+    // Simulating access after the extracted portion
+    arg.field = 16;
+}
+
+fn $0fun_name(arg: &mut Foo) {
+    arg.field = 8;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn reference_mutable_param_without_further_usages() {
+        check_assist(
+            extract_function,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    $0arg.field = 8;$0
+}
+"#,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    fun_name(arg);
+}
+
+fn $0fun_name(arg: &mut Foo) {
+    arg.field = 8;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_at_start() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0// comment here!
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    // comment here!
+    let x = 0;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_in_between() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;$0
+    let a = 0;
+    // comment here!
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let a = 0;
+    // comment here!
+    let x = 0;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_at_end() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0let x = 0;
+    // comment here!$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let x = 0;
+    // comment here!
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_indented() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0let x = 0;
+    while(true) {
+        // comment here!
+    }$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let x = 0;
+    while(true) {
+        // comment here!
+    }
+}
+"#,
+        );
+    }
+
+    // FIXME: we do want to preserve whitespace
+    #[test]
+    fn extract_function_does_not_preserve_whitespace() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0let a = 0;
+
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let a = 0;
+    let x = 0;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_long_form_comment() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0/* a comment */
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    /* a comment */
+    let x = 0;
+}
 "#,
         );
     }