]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/extract_function.rs
Merge #11842
[rust.git] / crates / ide_assists / src / handlers / extract_function.rs
index 3a334efe0ab9385ce9b2a26ff5f9cb4b12e0bd4b..fcb4aaf065d4bab0269d46b4aa9f2e678c0da7b4 100644 (file)
@@ -1,21 +1,18 @@
-use std::{hash::BuildHasherDefault, iter};
+use std::iter;
 
 use ast::make;
 use either::Either;
 use hir::{HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo};
 use ide_db::{
     defs::{Definition, NameRefClass},
-    helpers::{
-        insert_use::{insert_use, ImportScope},
-        mod_path_to_ast,
-        node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr},
-        FamousDefs,
-    },
+    famous_defs::FamousDefs,
+    helpers::mod_path_to_ast,
+    imports::insert_use::{insert_use, ImportScope},
     search::{FileReference, ReferenceCategory, SearchScope},
-    RootDatabase,
+    syntax_helpers::node_ext::{preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr},
+    FxIndexSet, RootDatabase,
 };
 use itertools::Itertools;
-use rustc_hash::FxHasher;
 use stdx::format_to;
 use syntax::{
     ast::{
     AssistId,
 };
 
-type FxIndexSet<T> = indexmap::IndexSet<T, BuildHasherDefault<FxHasher>>;
-
 // Assist: extract_function
 //
-// Extracts selected statements into new function.
+// Extracts selected statements and comments into new function.
 //
 // ```
 // fn main() {
 //     let n = 1;
 //     $0let m = n + 2;
+//     // calculate
 //     let k = m + n;$0
 //     let g = 3;
 // }
@@ -57,6 +53,7 @@
 //
 // fn $0fun_name(n: i32) {
 //     let m = n + 2;
+//     // calculate
 //     let k = m + n;
 // }
 // ```
@@ -76,6 +73,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
         syntax::NodeOrToken::Node(n) => n,
         syntax::NodeOrToken::Token(t) => t.parent()?,
     };
+
     let body = extraction_target(&node, range)?;
     let container_info = body.analyze_container(&ctx.sema)?;
 
@@ -289,10 +287,10 @@ enum FlowKind {
     Try {
         kind: TryKind,
     },
-    /// Break with value (`break $expr;`)
-    Break(Option<ast::Expr>),
-    /// Continue
-    Continue,
+    /// Break with label and value (`break 'label $expr;`)
+    Break(Option<ast::Lifetime>, Option<ast::Expr>),
+    /// Continue with label (`continue 'label;`)
+    Continue(Option<ast::Lifetime>),
 }
 
 #[derive(Debug, Clone)]
@@ -393,7 +391,7 @@ fn to_arg(&self, ctx: &AssistContext) -> ast::Expr {
     }
 
     fn to_param(&self, ctx: &AssistContext, module: hir::Module) -> ast::Param {
-        let var = self.var.name(ctx.db()).unwrap().to_string();
+        let var = self.var.name(ctx.db()).to_string();
         let var_name = make::name(&var);
         let pat = match self.kind() {
             ParamKind::MutValue => make::ident_pat(false, true, var_name),
@@ -435,21 +433,21 @@ impl FlowKind {
     fn make_result_handler(&self, expr: Option<ast::Expr>) -> ast::Expr {
         match self {
             FlowKind::Return(_) => make::expr_return(expr),
-            FlowKind::Break(_) => make::expr_break(expr),
+            FlowKind::Break(label, _) => make::expr_break(label.clone(), expr),
             FlowKind::Try { .. } => {
                 stdx::never!("cannot have result handler with try");
                 expr.unwrap_or_else(|| make::expr_return(None))
             }
-            FlowKind::Continue => {
+            FlowKind::Continue(label) => {
                 stdx::always!(expr.is_none(), "continue with value is not possible");
-                make::expr_continue()
+                make::expr_continue(label.clone())
             }
         }
     }
 
     fn expr_ty(&self, ctx: &AssistContext) -> Option<hir::Type> {
         match self {
-            FlowKind::Return(Some(expr)) | FlowKind::Break(Some(expr)) => {
+            FlowKind::Return(Some(expr)) | FlowKind::Break(_, Some(expr)) => {
                 ctx.sema.type_of_expr(expr).map(TypeInfo::adjusted)
             }
             FlowKind::Try { .. } => {
@@ -479,11 +477,14 @@ fn from_expr(expr: ast::Expr) -> Option<Self> {
     }
 
     fn from_range(parent: ast::StmtList, selected: TextRange) -> FunctionBody {
-        let mut text_range = parent
-            .statements()
-            .map(|stmt| stmt.syntax().text_range())
-            .filter(|&stmt| selected.intersect(stmt).filter(|it| !it.is_empty()).is_some())
-            .fold1(|acc, stmt| acc.cover(stmt));
+        let full_body = parent.syntax().children_with_tokens();
+
+        let mut text_range = full_body
+            .filter(|it| ast::Stmt::can_cast(it.kind()) || it.kind() == COMMENT)
+            .map(|element| element.text_range())
+            .filter(|&range| selected.intersect(range).filter(|it| !it.is_empty()).is_some())
+            .reduce(|acc, stmt| acc.cover(stmt));
+
         if let Some(tail_range) = parent
             .tail_expr()
             .map(|it| it.syntax().text_range())
@@ -838,8 +839,8 @@ fn external_control_flow(
                 cov_mark::hit!(external_control_flow_break_and_continue);
                 return None;
             }
-            (None, None, Some(b), None) => Some(FlowKind::Break(b.expr())),
-            (None, None, None, Some(_)) => Some(FlowKind::Continue),
+            (None, None, Some(b), None) => Some(FlowKind::Break(b.lifetime(), b.expr())),
+            (None, None, None, Some(c)) => Some(FlowKind::Continue(c.lifetime())),
             (None, None, None, None) => None,
         };
 
@@ -881,7 +882,7 @@ fn extracted_function_params(
                 // We can move the value into the function call if it's not used after the call,
                 // if the var is not used but defined outside a loop we are extracting from we can't move it either
                 // as the function will reuse it in the next iteration.
-                let move_local = !has_usages && defined_outside_parent_loop;
+                let move_local = (!has_usages && defined_outside_parent_loop) || ty.is_reference();
                 Param { var, ty, move_local, requires_mut, is_copy }
             })
             .collect()
@@ -1141,12 +1142,12 @@ fn make_call(ctx: &AssistContext, fun: &Function, indent: IndentLevel) -> String
     match fun.outliving_locals.as_slice() {
         [] => {}
         [var] => {
-            format_to!(buf, "let {}{} = ", mut_modifier(var), var.local.name(ctx.db()).unwrap())
+            format_to!(buf, "let {}{} = ", mut_modifier(var), var.local.name(ctx.db()))
         }
         vars => {
             buf.push_str("let (");
             let bindings = vars.iter().format_with(", ", |local, f| {
-                f(&format_args!("{}{}", mut_modifier(local), local.local.name(ctx.db()).unwrap()))
+                f(&format_args!("{}{}", mut_modifier(local), local.local.name(ctx.db())))
             });
             format_to!(buf, "{}", bindings);
             buf.push_str(") = ");
@@ -1184,20 +1185,20 @@ fn from_ret_ty(fun: &Function, ret_ty: &FunType) -> FlowHandler {
                 let action = flow_kind.clone();
                 if *ret_ty == FunType::Unit {
                     match flow_kind {
-                        FlowKind::Return(None) | FlowKind::Break(None) | FlowKind::Continue => {
-                            FlowHandler::If { action }
-                        }
-                        FlowKind::Return(_) | FlowKind::Break(_) => {
+                        FlowKind::Return(None)
+                        | FlowKind::Break(_, None)
+                        | FlowKind::Continue(_) => FlowHandler::If { action },
+                        FlowKind::Return(_) | FlowKind::Break(_, _) => {
                             FlowHandler::IfOption { action }
                         }
                         FlowKind::Try { kind } => FlowHandler::Try { kind: kind.clone() },
                     }
                 } else {
                     match flow_kind {
-                        FlowKind::Return(None) | FlowKind::Break(None) | FlowKind::Continue => {
-                            FlowHandler::MatchOption { none: action }
-                        }
-                        FlowKind::Return(_) | FlowKind::Break(_) => {
+                        FlowKind::Return(None)
+                        | FlowKind::Break(_, None)
+                        | FlowKind::Continue(_) => FlowHandler::MatchOption { none: action },
+                        FlowKind::Return(_) | FlowKind::Break(_, _) => {
                             FlowHandler::MatchResult { err: action }
                         }
                         FlowKind::Try { kind } => FlowHandler::Try { kind: kind.clone() },
@@ -1216,28 +1217,26 @@ fn make_call_expr(&self, call_expr: ast::Expr) -> ast::Expr {
                 let stmt = make::expr_stmt(action);
                 let block = make::block_expr(iter::once(stmt.into()), None);
                 let controlflow_break_path = make::path_from_text("ControlFlow::Break");
-                let condition = make::condition(
+                let condition = make::expr_let(
+                    make::tuple_struct_pat(
+                        controlflow_break_path,
+                        iter::once(make::wildcard_pat().into()),
+                    )
+                    .into(),
                     call_expr,
-                    Some(
-                        make::tuple_struct_pat(
-                            controlflow_break_path,
-                            iter::once(make::wildcard_pat().into()),
-                        )
-                        .into(),
-                    ),
                 );
-                make::expr_if(condition, block, None)
+                make::expr_if(condition.into(), block, None)
             }
             FlowHandler::IfOption { action } => {
                 let path = make::ext::ident_path("Some");
                 let value_pat = make::ext::simple_ident_pat(make::name("value"));
                 let pattern = make::tuple_struct_pat(path, iter::once(value_pat.into()));
-                let cond = make::condition(call_expr, Some(pattern.into()));
+                let cond = make::expr_let(pattern.into(), call_expr);
                 let value = make::expr_path(make::ext::ident_path("value"));
                 let action_expr = action.make_result_handler(Some(value));
                 let action_stmt = make::expr_stmt(action_expr);
                 let then = make::block_expr(iter::once(action_stmt.into()), None);
-                make::expr_if(cond, then, None)
+                make::expr_if(cond.into(), then, None)
             }
             FlowHandler::MatchOption { none } => {
                 let some_name = "value";
@@ -1287,7 +1286,7 @@ fn make_call_expr(&self, call_expr: ast::Expr) -> ast::Expr {
 }
 
 fn path_expr_from_local(ctx: &AssistContext, var: Local) -> ast::Expr {
-    let name = var.name(ctx.db()).unwrap().to_string();
+    let name = var.name(ctx.db()).to_string();
     make::expr_path(make::ext::ident_path(&name))
 }
 
@@ -1423,6 +1422,7 @@ fn make_body(
     } else {
         FlowHandler::from_ret_ty(fun, &ret_ty)
     };
+
     let block = match &fun.body {
         FunctionBody::Expr(expr) => {
             let expr = rewrite_body_segment(ctx, &fun.params, &handler, expr.syntax());
@@ -1444,21 +1444,28 @@ fn make_body(
         FunctionBody::Span { parent, text_range } => {
             let mut elements: Vec<_> = parent
                 .syntax()
-                .children()
+                .children_with_tokens()
                 .filter(|it| text_range.contains_range(it.text_range()))
-                .map(|it| rewrite_body_segment(ctx, &fun.params, &handler, &it))
+                .map(|it| match &it {
+                    syntax::NodeOrToken::Node(n) => syntax::NodeOrToken::Node(
+                        rewrite_body_segment(ctx, &fun.params, &handler, n),
+                    ),
+                    _ => it,
+                })
                 .collect();
 
-            let mut tail_expr = match elements.pop() {
-                Some(node) => ast::Expr::cast(node.clone()).or_else(|| {
-                    elements.push(node);
-                    None
-                }),
-                None => None,
+            let mut tail_expr = match &elements.last() {
+                Some(syntax::NodeOrToken::Node(node)) if ast::Expr::can_cast(node.kind()) => {
+                    ast::Expr::cast(node.clone())
+                }
+                _ => None,
             };
 
-            if tail_expr.is_none() {
-                match fun.outliving_locals.as_slice() {
+            match tail_expr {
+                Some(_) => {
+                    elements.pop();
+                }
+                None => match fun.outliving_locals.as_slice() {
                     [] => {}
                     [var] => {
                         tail_expr = Some(path_expr_from_local(ctx, var.local));
@@ -1468,22 +1475,27 @@ fn make_body(
                         let expr = make::expr_tuple(exprs);
                         tail_expr = Some(expr);
                     }
-                }
-            }
-
-            let elements = elements.into_iter().filter_map(|node| match ast::Stmt::cast(node) {
-                Some(stmt) => Some(stmt),
-                None => {
-                    stdx::never!("block contains non-statement");
-                    None
-                }
-            });
+                },
+            };
 
             let body_indent = IndentLevel(1);
-            let elements = elements.map(|stmt| stmt.dedent(old_indent).indent(body_indent));
+            let elements = elements
+                .into_iter()
+                .map(|node_or_token| match &node_or_token {
+                    syntax::NodeOrToken::Node(node) => match ast::Stmt::cast(node.clone()) {
+                        Some(stmt) => {
+                            let indented = stmt.dedent(old_indent).indent(body_indent);
+                            let ast_node = indented.syntax().clone_subtree();
+                            syntax::NodeOrToken::Node(ast_node)
+                        }
+                        _ => node_or_token,
+                    },
+                    _ => node_or_token,
+                })
+                .collect::<Vec<SyntaxElement>>();
             let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent));
 
-            make::block_expr(elements, tail_expr)
+            make::hacky_block_expr_with_comments(elements, tail_expr)
         }
     };
 
@@ -1506,7 +1518,7 @@ fn make_body(
                 make::expr_path(make::path_from_text("ControlFlow::Continue")),
                 make::arg_list(iter::once(make::expr_unit())),
             );
-            with_tail_expr(block, controlflow_continue.into())
+            with_tail_expr(block, controlflow_continue)
         }
         FlowHandler::IfOption { .. } => {
             let none = make::expr_path(make::ext::ident_path("None"));
@@ -3392,6 +3404,76 @@ fn $0fun_name(n: i32) -> ControlFlow<()> {
         );
     }
 
+    #[test]
+    fn break_loop_nested_labeled() {
+        check_assist(
+            extract_function,
+            r#"
+//- minicore: try
+fn foo() {
+    'bar: loop {
+        loop {
+            $0break 'bar;$0
+        }
+    }
+}
+"#,
+            r#"
+use core::ops::ControlFlow;
+
+fn foo() {
+    'bar: loop {
+        loop {
+            if let ControlFlow::Break(_) = fun_name() {
+                break 'bar;
+            }
+        }
+    }
+}
+
+fn $0fun_name() -> ControlFlow<()> {
+    return ControlFlow::Break(());
+    ControlFlow::Continue(())
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn continue_loop_nested_labeled() {
+        check_assist(
+            extract_function,
+            r#"
+//- minicore: try
+fn foo() {
+    'bar: loop {
+        loop {
+            $0continue 'bar;$0
+        }
+    }
+}
+"#,
+            r#"
+use core::ops::ControlFlow;
+
+fn foo() {
+    'bar: loop {
+        loop {
+            if let ControlFlow::Break(_) = fun_name() {
+                continue 'bar;
+            }
+        }
+    }
+}
+
+fn $0fun_name() -> ControlFlow<()> {
+    return ControlFlow::Break(());
+    ControlFlow::Continue(())
+}
+"#,
+        );
+    }
+
     #[test]
     fn return_from_nested_loop() {
         check_assist(
@@ -3596,6 +3678,46 @@ fn $0fun_name() -> Option<i32> {
         );
     }
 
+    #[test]
+    fn break_with_value_and_label() {
+        check_assist(
+            extract_function,
+            r#"
+fn foo() -> i32 {
+    'bar: loop {
+        let n = 1;
+        $0let k = 1;
+        if k == 42 {
+            break 'bar 4;
+        }
+        let m = k + 1;$0
+        let h = 1;
+    }
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    'bar: loop {
+        let n = 1;
+        if let Some(value) = fun_name() {
+            break 'bar value;
+        }
+        let h = 1;
+    }
+}
+
+fn $0fun_name() -> Option<i32> {
+    let k = 1;
+    if k == 42 {
+        return Some(4);
+    }
+    let m = k + 1;
+    None
+}
+"#,
+        );
+    }
+
     #[test]
     fn break_with_value_and_return() {
         check_assist(
@@ -4040,7 +4162,7 @@ fn main() {
     match 6 {
         100 => $0{ 100 }$0
         _ => 0,
-    }
+    };
 }
 "#,
             r#"
@@ -4048,7 +4170,7 @@ fn main() {
     match 6 {
         100 => fun_name(),
         _ => 0,
-    }
+    };
 }
 
 fn $0fun_name() -> i32 {
@@ -4063,7 +4185,7 @@ fn main() {
     match 6 {
         100 => $0{ 100 }$0,
         _ => 0,
-    }
+    };
 }
 "#,
             r#"
@@ -4071,7 +4193,7 @@ fn main() {
     match 6 {
         100 => fun_name(),
         _ => 0,
-    }
+    };
 }
 
 fn $0fun_name() -> i32 {
@@ -4095,14 +4217,35 @@ fn foo() {
 "#,
             r#"
 fn foo() {
-    /**/
     fun_name();
-    /**/
 }
 
 fn $0fun_name() {
+    /**/
     foo();
     foo();
+    /**/
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_does_not_tear_body_apart() {
+        check_assist(
+            extract_function,
+            r#"
+fn foo() {
+    $0foo();
+}$0
+"#,
+            r#"
+fn foo() {
+    fun_name();
+}
+
+fn $0fun_name() {
+    foo();
 }
 "#,
         );
@@ -4335,6 +4478,226 @@ fn foo() {
 fn $0fun_name(a: _) -> _ {
     a
 }
+"#,
+        );
+    }
+
+    #[test]
+    fn reference_mutable_param_with_further_usages() {
+        check_assist(
+            extract_function,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    $0arg.field = 8;$0
+    // Simulating access after the extracted portion
+    arg.field = 16;
+}
+"#,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    fun_name(arg);
+    // Simulating access after the extracted portion
+    arg.field = 16;
+}
+
+fn $0fun_name(arg: &mut Foo) {
+    arg.field = 8;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn reference_mutable_param_without_further_usages() {
+        check_assist(
+            extract_function,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    $0arg.field = 8;$0
+}
+"#,
+            r#"
+pub struct Foo {
+    field: u32,
+}
+
+pub fn testfn(arg: &mut Foo) {
+    fun_name(arg);
+}
+
+fn $0fun_name(arg: &mut Foo) {
+    arg.field = 8;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_at_start() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0// comment here!
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    // comment here!
+    let x = 0;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_in_between() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;$0
+    let a = 0;
+    // comment here!
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let a = 0;
+    // comment here!
+    let x = 0;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_at_end() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0let x = 0;
+    // comment here!$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let x = 0;
+    // comment here!
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_copies_comment_indented() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0let x = 0;
+    while(true) {
+        // comment here!
+    }$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let x = 0;
+    while(true) {
+        // comment here!
+    }
+}
+"#,
+        );
+    }
+
+    // FIXME: we do want to preserve whitespace
+    #[test]
+    fn extract_function_does_not_preserve_whitespace() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0let a = 0;
+
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    let a = 0;
+    let x = 0;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn extract_function_long_form_comment() {
+        check_assist(
+            extract_function,
+            r#"
+fn func() {
+    let i = 0;
+    $0/* a comment */
+    let x = 0;$0
+}
+"#,
+            r#"
+fn func() {
+    let i = 0;
+    fun_name();
+}
+
+fn $0fun_name() {
+    /* a comment */
+    let x = 0;
+}
 "#,
         );
     }