]> git.lizzy.rs Git - rust.git/blobdiff - src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_function.rs
:arrow_up: rust-analyzer
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / generate_function.rs
index e26c76da1891649c0a035706e5eb278aa325b99e..8b67982f9158234b91afa04bf28c691b543b2d10 100644 (file)
@@ -1,4 +1,4 @@
-use hir::{HasSource, HirDisplay, Module, Semantics, TypeInfo};
+use hir::{Adt, HasSource, HirDisplay, Module, Semantics, TypeInfo};
 use ide_db::{
     base_db::FileId,
     defs::{Definition, NameRefClass},
@@ -145,7 +145,8 @@ fn gen_method(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
         return None;
     }
     let (impl_, file) = get_adt_source(ctx, &adt, fn_name.text().as_str())?;
-    let (target, insert_offset) = get_method_target(ctx, &target_module, &impl_)?;
+    let (target, insert_offset) = get_method_target(ctx, &impl_, &adt)?;
+
     let function_builder =
         FunctionBuilder::from_method_call(ctx, &call, &fn_name, target_module, target)?;
     let text_range = call.syntax().text_range();
@@ -174,10 +175,11 @@ fn add_func_to_accumulator(
     label: String,
 ) -> Option<()> {
     acc.add(AssistId("generate_function", AssistKind::Generate), label, text_range, |builder| {
-        let function_template = function_builder.render();
+        let indent = IndentLevel::from_node(function_builder.target.syntax());
+        let function_template = function_builder.render(adt_name.is_some());
         let mut func = function_template.to_string(ctx.config.snippet_cap);
         if let Some(name) = adt_name {
-            func = format!("\nimpl {} {{\n{}\n}}", name, func);
+            func = format!("\n{}impl {} {{\n{}\n{}}}", indent, name, func, indent);
         }
         builder.edit_file(file);
         match ctx.config.snippet_cap {
@@ -307,7 +309,7 @@ fn from_method_call(
         })
     }
 
-    fn render(self) -> FunctionTemplate {
+    fn render(self, is_method: bool) -> FunctionTemplate {
         let placeholder_expr = make::ext::expr_todo();
         let fn_body = make::block_expr(vec![], Some(placeholder_expr));
         let visibility = if self.needs_pub { Some(make::visibility_pub_crate()) } else { None };
@@ -325,8 +327,14 @@ fn render(self) -> FunctionTemplate {
 
         match self.target {
             GeneratedFunctionTarget::BehindItem(it) => {
-                let indent = IndentLevel::from_node(&it);
-                leading_ws = format!("\n\n{}", indent);
+                let mut indent = IndentLevel::from_node(&it);
+                if is_method {
+                    indent = indent + 1;
+                    leading_ws = format!("{}", indent);
+                } else {
+                    leading_ws = format!("\n\n{}", indent);
+                }
+
                 fn_def = fn_def.indent(indent);
                 trailing_ws = String::new();
             }
@@ -411,14 +419,13 @@ fn get_fn_target(
 
 fn get_method_target(
     ctx: &AssistContext<'_>,
-    target_module: &Module,
     impl_: &Option<ast::Impl>,
+    adt: &Adt,
 ) -> Option<(GeneratedFunctionTarget, TextSize)> {
     let target = match impl_ {
         Some(impl_) => next_space_for_fn_in_impl(impl_)?,
         None => {
-            next_space_for_fn_in_module(ctx.sema.db, &target_module.definition_source(ctx.sema.db))?
-                .1
+            GeneratedFunctionTarget::BehindItem(adt.source(ctx.sema.db)?.syntax().value.clone())
         }
     };
     Some((target.clone(), get_insert_offset(&target)))
@@ -437,7 +444,7 @@ fn assoc_fn_target_info(
         return None;
     }
     let (impl_, file) = get_adt_source(ctx, &adt, fn_name)?;
-    let (target, insert_offset) = get_method_target(ctx, &module, &impl_)?;
+    let (target, insert_offset) = get_method_target(ctx, &impl_, &adt)?;
     let adt_name = if impl_.is_none() { Some(adt.name(ctx.sema.db)) } else { None };
     Some(TargetInfo::new(target_module, adt_name, target, file, insert_offset))
 }
@@ -1468,14 +1475,12 @@ fn create_method() {
 ",
             r"
 struct S;
-fn foo() {S.bar();}
 impl S {
-
-
-fn bar(&self) ${0:-> _} {
-    todo!()
-}
+    fn bar(&self) ${0:-> _} {
+        todo!()
+    }
 }
+fn foo() {S.bar();}
 ",
         )
     }
@@ -1516,14 +1521,12 @@ mod s {
             r"
 mod s {
     pub struct S;
-impl S {
-
-
-    pub(crate) fn bar(&self) ${0:-> _} {
-        todo!()
+    impl S {
+        pub(crate) fn bar(&self) ${0:-> _} {
+            todo!()
+        }
     }
 }
-}
 fn foo() {s::S.bar();}
 ",
         )
@@ -1544,18 +1547,16 @@ fn foo() {
 ",
             r"
 struct S;
+impl S {
+    fn bar(&self) ${0:-> _} {
+        todo!()
+    }
+}
 mod s {
     fn foo() {
         super::S.bar();
     }
 }
-impl S {
-
-
-fn bar(&self) ${0:-> _} {
-    todo!()
-}
-}
 
 ",
         )
@@ -1571,14 +1572,12 @@ fn create_method_with_cursor_anywhere_on_call_expresion() {
 ",
             r"
 struct S;
-fn foo() {S.bar();}
 impl S {
-
-
-fn bar(&self) ${0:-> _} {
-    todo!()
-}
+    fn bar(&self) ${0:-> _} {
+        todo!()
+    }
 }
+fn foo() {S.bar();}
 ",
         )
     }
@@ -1593,14 +1592,12 @@ fn create_static_method() {
 ",
             r"
 struct S;
-fn foo() {S::bar();}
 impl S {
-
-
-fn bar() ${0:-> _} {
-    todo!()
-}
+    fn bar() ${0:-> _} {
+        todo!()
+    }
 }
+fn foo() {S::bar();}
 ",
         )
     }
@@ -1641,14 +1638,12 @@ mod s {
             r"
 mod s {
     pub struct S;
-impl S {
-
-
-    pub(crate) fn bar() ${0:-> _} {
-        todo!()
+    impl S {
+        pub(crate) fn bar() ${0:-> _} {
+            todo!()
+        }
     }
 }
-}
 fn foo() {s::S::bar();}
 ",
         )
@@ -1664,14 +1659,12 @@ fn create_static_method_with_cursor_anywhere_on_call_expresion() {
 ",
             r"
 struct S;
-fn foo() {S::bar();}
 impl S {
-
-
-fn bar() ${0:-> _} {
-    todo!()
-}
+    fn bar() ${0:-> _} {
+        todo!()
+    }
 }
+fn foo() {S::bar();}
 ",
         )
     }
@@ -1841,15 +1834,13 @@ fn main() {
 ",
             r"
 enum Foo {}
-fn main() {
-    Foo::new();
-}
 impl Foo {
-
-
-fn new() ${0:-> _} {
-    todo!()
+    fn new() ${0:-> _} {
+        todo!()
+    }
 }
+fn main() {
+    Foo::new();
 }
 ",
         )