]> git.lizzy.rs Git - rust.git/commitdiff
correctly handle mutable references
authorJeroen Vannevel <jer_vannevel@outlook.com>
Wed, 5 Jan 2022 01:03:27 +0000 (01:03 +0000)
committerJeroen Vannevel <jer_vannevel@outlook.com>
Wed, 5 Jan 2022 01:03:27 +0000 (01:03 +0000)
crates/ide_assists/src/handlers/extract_variable.rs

index 7a57ab3b9b7cc2b45722653605c480e6881d3ed3..d7a8e1dd4c2bc543a4e25eb214ae36ac4fdcf2a9 100644 (file)
@@ -52,6 +52,12 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option
         }
     }
 
+    let is_mutable_reference = if let Some(receiver_type) = get_receiver_type(&ctx, &to_extract) {
+        receiver_type.is_mutable_reference()
+    } else {
+        false
+    };
+
     let anchor = Anchor::from(&to_extract)?;
     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
     let target = to_extract.syntax().text_range();
@@ -77,11 +83,15 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option
                 None => to_extract.syntax().text_range(),
             };
 
+            let reference_modifier = if is_mutable_reference { "&mut " } else { "" };
+
             match anchor {
                 Anchor::Before(_) | Anchor::Replace(_) => {
-                    format_to!(buf, "let {} = ", var_name)
+                    format_to!(buf, "let {} = {}", var_name, reference_modifier)
+                }
+                Anchor::WrapInBlock(_) => {
+                    format_to!(buf, "{{ let {} = {}", var_name, reference_modifier)
                 }
-                Anchor::WrapInBlock(_) => format_to!(buf, "{{ let {} = ", var_name),
             };
             format_to!(buf, "{}", to_extract.syntax());
 
@@ -146,6 +156,22 @@ fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
     }
 }
 
+fn get_receiver_type(ctx: &AssistContext, expression: &ast::Expr) -> Option<hir::Type> {
+    let receiver = get_receiver(expression.to_owned())?;
+    Some(ctx.sema.type_of_expr(&receiver)?.original())
+}
+
+fn get_receiver(expression: ast::Expr) -> Option<ast::Expr> {
+    match expression {
+        ast::Expr::FieldExpr(field) if field.expr().is_some() => {
+            let nested_expression = &field.expr()?;
+            get_receiver(nested_expression.to_owned())
+        }
+        ast::Expr::PathExpr(_) => Some(expression),
+        _ => None,
+    }
+}
+
 #[derive(Debug)]
 enum Anchor {
     Before(SyntaxNode),
@@ -900,4 +926,64 @@ fn extract_var_no_block_body() {
 ",
         );
     }
+
+    #[test]
+    fn test_extract_var_mutable_reference_parameter() {
+        check_assist(
+            extract_variable,
+            r#"
+struct S {
+    vec: Vec<u8>
+}
+
+fn foo(s: &mut S) {
+    $0s.vec$0.push(0);
+}"#,
+            r#"
+struct S {
+    vec: Vec<u8>
+}
+
+fn foo(s: &mut S) {
+    let $0var_name = &mut s.vec;
+    var_name.push(0);
+}"#,
+        );
+    }
+
+    #[test]
+    fn test_extract_var_mutable_reference_parameter_deep_nesting() {
+        check_assist(
+            extract_variable,
+            r#"
+struct Y {
+    field: X
+}
+struct X {
+    field: S
+}
+struct S {
+    vec: Vec<u8>
+}
+
+fn foo(f: &mut Y) {
+    $0f.field.field.vec$0.push(0);
+}"#,
+            r#"
+struct Y {
+    field: X
+}
+struct X {
+    field: S
+}
+struct S {
+    vec: Vec<u8>
+}
+
+fn foo(f: &mut Y) {
+    let $0var_name = &mut f.field.field.vec;
+    var_name.push(0);
+}"#,
+        );
+    }
 }