]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/extract_variable.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
index 176f52aa4d4130b9c31aeff9f8f8d2fda7bf40d9..aaed2b67fe8f56b087a5b5762e8d41129107411a 100644 (file)
@@ -1,6 +1,7 @@
 use stdx::format_to;
 use syntax::{
     ast::{self, AstNode},
+    NodeOrToken,
     SyntaxKind::{
         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
         PATH_EXPR, RETURN_EXPR,
 // }
 // ```
 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
-    if ctx.frange.range.is_empty() {
-        return None;
-    }
-    let node = ctx.covering_element();
-    if node.kind() == COMMENT {
-        cov_mark::hit!(extract_var_in_comment_is_not_applicable);
+    if ctx.has_empty_selection() {
         return None;
     }
+
+    let node = match ctx.covering_element() {
+        NodeOrToken::Node(it) => it,
+        NodeOrToken::Token(it) if it.kind() == COMMENT => {
+            cov_mark::hit!(extract_var_in_comment_is_not_applicable);
+            return None;
+        }
+        NodeOrToken::Token(it) => it.parent()?,
+    };
+    let node = node.ancestors().take_while(|anc| anc.text_range() == node.text_range()).last()?;
     let to_extract = node
-        .ancestors()
-        .take_while(|it| it.text_range().contains_range(ctx.frange.range))
+        .descendants()
+        .take_while(|it| ctx.selection_trimmed().contains_range(it.text_range()))
         .find_map(valid_target_expr)?;
-    if let Some(ty) = ctx.sema.type_of_expr(&to_extract) {
-        if ty.is_unit() {
+
+    if let Some(ty_info) = ctx.sema.type_of_expr(&to_extract) {
+        if ty_info.adjusted().is_unit() {
             return None;
         }
     }
+
+    let reference_modifier = match get_receiver_type(&ctx, &to_extract) {
+        Some(receiver_type) if receiver_type.is_mutable_reference() => "&mut ",
+        Some(receiver_type) if receiver_type.is_reference() => "&",
+        _ => "",
+    };
+
     let anchor = Anchor::from(&to_extract)?;
     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
     let target = to_extract.syntax().text_range();
@@ -69,17 +83,20 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option
                 None => to_extract.syntax().text_range(),
             };
 
-            if let Anchor::WrapInBlock(_) = anchor {
-                format_to!(buf, "{{ let {} = ", var_name);
-            } else {
-                format_to!(buf, "let {} = ", var_name);
+            match anchor {
+                Anchor::Before(_) | Anchor::Replace(_) => {
+                    format_to!(buf, "let {} = {}", var_name, reference_modifier)
+                }
+                Anchor::WrapInBlock(_) => {
+                    format_to!(buf, "{{ let {} = {}", var_name, reference_modifier)
+                }
             };
             format_to!(buf, "{}", to_extract.syntax());
 
             if let Anchor::Replace(stmt) = anchor {
                 cov_mark::hit!(test_extract_var_expr_stmt);
                 if stmt.semicolon_token().is_none() {
-                    buf.push_str(";");
+                    buf.push(';');
                 }
                 match ctx.config.snippet_cap {
                     Some(cap) => {
@@ -92,7 +109,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option
                 return;
             }
 
-            buf.push_str(";");
+            buf.push(';');
 
             // We want to maintain the indent level,
             // but we do not want to duplicate possible
@@ -137,6 +154,23 @@ fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
     }
 }
 
+fn get_receiver_type(ctx: &AssistContext, expression: &ast::Expr) -> Option<hir::Type> {
+    let receiver = get_receiver(expression.clone())?;
+    Some(ctx.sema.type_of_expr(&receiver)?.original())
+}
+
+/// In the expression `a.b.c.x()`, find `a`
+fn get_receiver(expression: ast::Expr) -> Option<ast::Expr> {
+    match expression {
+        ast::Expr::FieldExpr(field) if field.expr().is_some() => {
+            let nested_expression = &field.expr()?;
+            get_receiver(nested_expression.to_owned())
+        }
+        _ => Some(expression),
+    }
+}
+
+#[derive(Debug)]
 enum Anchor {
     Before(SyntaxNode),
     Replace(ast::ExprStmt),
@@ -145,10 +179,16 @@ enum Anchor {
 
 impl Anchor {
     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
-        to_extract.syntax().ancestors().take_while(|it| !ast::Item::can_cast(it.kind())).find_map(
-            |node| {
+        to_extract
+            .syntax()
+            .ancestors()
+            .take_while(|it| !ast::Item::can_cast(it.kind()) || ast::MacroCall::can_cast(it.kind()))
+            .find_map(|node| {
+                if ast::MacroCall::can_cast(node.kind()) {
+                    return None;
+                }
                 if let Some(expr) =
-                    node.parent().and_then(ast::BlockExpr::cast).and_then(|it| it.tail_expr())
+                    node.parent().and_then(ast::StmtList::cast).and_then(|it| it.tail_expr())
                 {
                     if expr.syntax() == &node {
                         cov_mark::hit!(test_extract_var_last_expr);
@@ -180,8 +220,7 @@ fn from(to_extract: &ast::Expr) -> Option<Anchor> {
                     return Some(Anchor::Before(node));
                 }
                 None
-            },
-        )
+            })
     }
 
     fn syntax(&self) -> &SyntaxNode {
@@ -227,7 +266,7 @@ fn test_extract_var_expr_stmt() {
             extract_variable,
             r#"
 fn foo() {
-    $01 + 1$0;
+  $0  1 + 1$0;
 }"#,
             r#"
 fn foo() {
@@ -236,12 +275,12 @@ fn foo() {
         );
         check_assist(
             extract_variable,
-            "
+            r"
 fn foo() {
     $0{ let x = 0; x }$0
     something_else();
 }",
-            "
+            r"
 fn foo() {
     let $0var_name = { let x = 0; x };
     something_else();
@@ -253,11 +292,11 @@ fn foo() {
     fn test_extract_var_part_of_expr_stmt() {
         check_assist(
             extract_variable,
-            "
+            r"
 fn foo() {
     $01$0 + 1;
 }",
-            "
+            r"
 fn foo() {
     let $0var_name = 1;
     var_name + 1;
@@ -804,6 +843,32 @@ fn foo() {
         )
     }
 
+    #[test]
+    fn extract_macro_call() {
+        check_assist(
+            extract_variable,
+            r"
+struct Vec;
+macro_rules! vec {
+    () => {Vec}
+}
+fn main() {
+    let _ = $0vec![]$0;
+}
+",
+            r"
+struct Vec;
+macro_rules! vec {
+    () => {Vec}
+}
+fn main() {
+    let $0vec = vec![];
+    let _ = vec;
+}
+",
+        );
+    }
+
     #[test]
     fn test_extract_var_for_return_not_applicable() {
         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
@@ -859,4 +924,330 @@ fn extract_var_no_block_body() {
 ",
         );
     }
+
+    #[test]
+    fn test_extract_var_mutable_reference_parameter() {
+        check_assist(
+            extract_variable,
+            r#"
+struct S {
+    vec: Vec<u8>
+}
+
+fn foo(s: &mut S) {
+    $0s.vec$0.push(0);
+}"#,
+            r#"
+struct S {
+    vec: Vec<u8>
+}
+
+fn foo(s: &mut S) {
+    let $0var_name = &mut s.vec;
+    var_name.push(0);
+}"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_var_mutable_reference_parameter_deep_nesting() {
+        check_assist(
+            extract_variable,
+            r#"
+struct Y {
+    field: X
+}
+struct X {
+    field: S
+}
+struct S {
+    vec: Vec<u8>
+}
+
+fn foo(f: &mut Y) {
+    $0f.field.field.vec$0.push(0);
+}"#,
+            r#"
+struct Y {
+    field: X
+}
+struct X {
+    field: S
+}
+struct S {
+    vec: Vec<u8>
+}
+
+fn foo(f: &mut Y) {
+    let $0var_name = &mut f.field.field.vec;
+    var_name.push(0);
+}"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_var_reference_parameter() {
+        check_assist(
+            extract_variable,
+            r#"
+struct X;
+
+impl X {
+    fn do_thing(&self) {
+
+    }
+}
+
+struct S {
+    sub: X
+}
+
+fn foo(s: &S) {
+    $0s.sub$0.do_thing();
+}"#,
+            r#"
+struct X;
+
+impl X {
+    fn do_thing(&self) {
+
+    }
+}
+
+struct S {
+    sub: X
+}
+
+fn foo(s: &S) {
+    let $0x = &s.sub;
+    x.do_thing();
+}"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_var_reference_parameter_deep_nesting() {
+        check_assist(
+            extract_variable,
+            r#"
+struct Z;
+impl Z {
+    fn do_thing(&self) {
+
+    }
+}
+
+struct Y {
+    field: Z
+}
+
+struct X {
+    field: Y
+}
+
+struct S {
+    sub: X
+}
+
+fn foo(s: &S) {
+    $0s.sub.field.field$0.do_thing();
+}"#,
+            r#"
+struct Z;
+impl Z {
+    fn do_thing(&self) {
+
+    }
+}
+
+struct Y {
+    field: Z
+}
+
+struct X {
+    field: Y
+}
+
+struct S {
+    sub: X
+}
+
+fn foo(s: &S) {
+    let $0z = &s.sub.field.field;
+    z.do_thing();
+}"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_var_regular_parameter() {
+        check_assist(
+            extract_variable,
+            r#"
+struct X;
+
+impl X {
+    fn do_thing(&self) {
+
+    }
+}
+
+struct S {
+    sub: X
+}
+
+fn foo(s: S) {
+    $0s.sub$0.do_thing();
+}"#,
+            r#"
+struct X;
+
+impl X {
+    fn do_thing(&self) {
+
+    }
+}
+
+struct S {
+    sub: X
+}
+
+fn foo(s: S) {
+    let $0x = s.sub;
+    x.do_thing();
+}"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_var_mutable_reference_local() {
+        check_assist(
+            extract_variable,
+            r#"
+struct X;
+
+struct S {
+    sub: X
+}
+
+impl S {
+    fn new() -> S {
+        S {
+            sub: X::new()
+        }
+    }
+}
+
+impl X {
+    fn new() -> X {
+        X { }
+    }
+    fn do_thing(&self) {
+
+    }
+}
+
+
+fn foo() {
+    let local = &mut S::new();
+    $0local.sub$0.do_thing();
+}"#,
+            r#"
+struct X;
+
+struct S {
+    sub: X
+}
+
+impl S {
+    fn new() -> S {
+        S {
+            sub: X::new()
+        }
+    }
+}
+
+impl X {
+    fn new() -> X {
+        X { }
+    }
+    fn do_thing(&self) {
+
+    }
+}
+
+
+fn foo() {
+    let local = &mut S::new();
+    let $0x = &mut local.sub;
+    x.do_thing();
+}"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_var_reference_local() {
+        check_assist(
+            extract_variable,
+            r#"
+struct X;
+
+struct S {
+    sub: X
+}
+
+impl S {
+    fn new() -> S {
+        S {
+            sub: X::new()
+        }
+    }
+}
+
+impl X {
+    fn new() -> X {
+        X { }
+    }
+    fn do_thing(&self) {
+
+    }
+}
+
+
+fn foo() {
+    let local = &S::new();
+    $0local.sub$0.do_thing();
+}"#,
+            r#"
+struct X;
+
+struct S {
+    sub: X
+}
+
+impl S {
+    fn new() -> S {
+        S {
+            sub: X::new()
+        }
+    }
+}
+
+impl X {
+    fn new() -> X {
+        X { }
+    }
+    fn do_thing(&self) {
+
+    }
+}
+
+
+fn foo() {
+    let local = &S::new();
+    let $0x = &local.sub;
+    x.do_thing();
+}"#,
+        );
+    }
 }