]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/generate_function.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / generate_function.rs
index 0255e508b4fc4733451d2ebe6950fdf5b6aec042..ac33d56858c03e89b1c354d7b870525d4fd6af0f 100644 (file)
@@ -1,23 +1,25 @@
-use hir::{HasSource, HirDisplay, Module, ModuleDef, Semantics, TypeInfo};
+use rustc_hash::{FxHashMap, FxHashSet};
+
+use hir::{HasSource, HirDisplay, Module, Semantics, TypeInfo};
+use ide_db::helpers::FamousDefs;
 use ide_db::{
     base_db::FileId,
     defs::{Definition, NameRefClass},
     helpers::SnippetCap,
     RootDatabase,
 };
-use rustc_hash::{FxHashMap, FxHashSet};
 use stdx::to_lower_snake_case;
 use syntax::{
     ast::{
         self,
         edit::{AstNodeEdit, IndentLevel},
-        make, ArgListOwner, AstNode, CallExpr, ModuleItemOwner,
+        make, AstNode, CallExpr, HasArgList, HasModuleItem,
     },
     SyntaxKind, SyntaxNode, TextRange, TextSize,
 };
 
 use crate::{
-    utils::useless_type_special_case,
+    utils::convert_reference_type,
     utils::{find_struct_impl, render_snippet, Cursor},
     AssistContext, AssistId, AssistKind, Assists,
 };
@@ -51,27 +53,6 @@ pub(crate) fn generate_function(acc: &mut Assists, ctx: &AssistContext) -> Optio
     gen_fn(acc, ctx).or_else(|| gen_method(acc, ctx))
 }
 
-enum FuncExpr {
-    Func(ast::CallExpr),
-    Method(ast::MethodCallExpr),
-}
-
-impl FuncExpr {
-    fn arg_list(&self) -> Option<ast::ArgList> {
-        match self {
-            FuncExpr::Func(fn_call) => fn_call.arg_list(),
-            FuncExpr::Method(m_call) => m_call.arg_list(),
-        }
-    }
-
-    fn syntax(&self) -> &SyntaxNode {
-        match self {
-            FuncExpr::Func(fn_call) => fn_call.syntax(),
-            FuncExpr::Method(m_call) => m_call.syntax(),
-        }
-    }
-}
-
 fn gen_fn(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
     let path_expr: ast::PathExpr = ctx.find_node_at_offset()?;
     let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?;
@@ -130,6 +111,10 @@ fn gen_fn(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
 
 fn gen_method(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
     let call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
+    if ctx.sema.resolve_method_call(&call).is_some() {
+        return None;
+    }
+
     let fn_name = call.name_ref()?;
     let adt = ctx.sema.type_of_expr(&call.receiver()?)?.original().strip_references().as_adt()?;
 
@@ -213,10 +198,9 @@ fn to_string(&self, cap: Option<SnippetCap>) -> String {
             Some(cap) => {
                 let cursor = if self.should_focus_return_type {
                     // Focus the return type if there is one
-                    if let Some(ref ret_type) = self.ret_type {
-                        ret_type.syntax()
-                    } else {
-                        self.tail_expr.syntax()
+                    match self.ret_type {
+                        Some(ref ret_type) => ret_type.syntax(),
+                        None => self.tail_expr.syntax(),
                     }
                 } else {
                     self.tail_expr.syntax()
@@ -254,7 +238,8 @@ fn from_call(
         let needs_pub = target_module.is_some();
         let target_module = target_module.or_else(|| current_module(target.syntax(), ctx))?;
         let fn_name = make::name(fn_name);
-        let (type_params, params) = fn_args(ctx, target_module, FuncExpr::Func(call.clone()))?;
+        let (type_params, params) =
+            fn_args(ctx, target_module, ast::CallableExpr::Call(call.clone()))?;
 
         let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
         let is_async = await_expr.is_some();
@@ -284,7 +269,8 @@ fn from_method_call(
         let needs_pub =
             !module_is_descendant(&current_module(call.syntax(), ctx)?, &target_module, ctx);
         let fn_name = make::name(&name.text());
-        let (type_params, params) = fn_args(ctx, target_module, FuncExpr::Method(call.clone()))?;
+        let (type_params, params) =
+            fn_args(ctx, target_module, ast::CallableExpr::MethodCall(call.clone()))?;
 
         let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
         let is_async = await_expr.is_some();
@@ -384,7 +370,7 @@ fn get_fn_target(
     target_module: &Option<Module>,
     call: CallExpr,
 ) -> Option<(GeneratedFunctionTarget, FileId, TextSize)> {
-    let mut file = ctx.frange.file_id;
+    let mut file = ctx.file_id();
     let target = match target_module {
         Some(target_module) => {
             let module_source = target_module.definition_source(ctx.db());
@@ -392,7 +378,7 @@ fn get_fn_target(
             file = in_file;
             target
         }
-        None => next_space_for_fn_after_call_site(FuncExpr::Func(call))?,
+        None => next_space_for_fn_after_call_site(ast::CallableExpr::Call(call))?,
     };
     Some((target.clone(), file, get_insert_offset(&target)))
 }
@@ -438,26 +424,13 @@ fn syntax(&self) -> &SyntaxNode {
 fn fn_args(
     ctx: &AssistContext,
     target_module: hir::Module,
-    call: FuncExpr,
+    call: ast::CallableExpr,
 ) -> Option<(Option<ast::GenericParamList>, ast::ParamList)> {
     let mut arg_names = Vec::new();
     let mut arg_types = Vec::new();
     for arg in call.arg_list()?.args() {
         arg_names.push(fn_arg_name(&ctx.sema, &arg));
-        arg_types.push(match fn_arg_type(ctx, target_module, &arg) {
-            Some(ty) => {
-                if !ty.is_empty() && ty.starts_with('&') {
-                    if let Some((new_ty, _)) = useless_type_special_case("", &ty[1..].to_owned()) {
-                        new_ty
-                    } else {
-                        ty
-                    }
-                } else {
-                    ty
-                }
-            }
-            None => String::from("_"),
-        });
+        arg_types.push(fn_arg_type(ctx, target_module, &arg));
     }
     deduplicate_arg_names(&mut arg_names);
     let params = arg_names.into_iter().zip(arg_types).map(|(name, ty)| {
@@ -468,8 +441,8 @@ fn fn_args(
         None,
         make::param_list(
             match call {
-                FuncExpr::Func(_) => None,
-                FuncExpr::Method(_) => Some(make::self_param()),
+                ast::CallableExpr::Call(_) => None,
+                ast::CallableExpr::MethodCall(_) => Some(make::self_param()),
             },
             params,
         ),
@@ -487,10 +460,10 @@ fn fn_args(
 /// assert_eq!(names, expected);
 /// ```
 fn deduplicate_arg_names(arg_names: &mut Vec<String>) {
-    let arg_name_counts = arg_names.iter().fold(FxHashMap::default(), |mut m, name| {
-        *m.entry(name).or_insert(0) += 1;
-        m
-    });
+    let mut arg_name_counts = FxHashMap::default();
+    for name in arg_names.iter() {
+        *arg_name_counts.entry(name).or_insert(0) += 1;
+    }
     let duplicate_arg_names: FxHashSet<String> = arg_name_counts
         .into_iter()
         .filter(|(_, count)| *count >= 2)
@@ -512,10 +485,14 @@ fn fn_arg_name(sema: &Semantics<RootDatabase>, arg_expr: &ast::Expr) -> String {
     let name = (|| match arg_expr {
         ast::Expr::CastExpr(cast_expr) => Some(fn_arg_name(sema, &cast_expr.expr()?)),
         expr => {
-            let name_ref = expr.syntax().descendants().filter_map(ast::NameRef::cast).last()?;
-            if let Some(NameRefClass::Definition(Definition::ModuleDef(
-                ModuleDef::Const(_) | ModuleDef::Static(_),
-            ))) = NameRefClass::classify(sema, &name_ref)
+            let name_ref = expr
+                .syntax()
+                .descendants()
+                .filter_map(ast::NameRef::cast)
+                .filter(|name| name.ident_token().is_some())
+                .last()?;
+            if let Some(NameRefClass::Definition(Definition::Const(_) | Definition::Static(_))) =
+                NameRefClass::classify(sema, &name_ref)
             {
                 return Some(name_ref.to_string().to_lowercase());
             };
@@ -532,28 +509,35 @@ fn fn_arg_name(sema: &Semantics<RootDatabase>, arg_expr: &ast::Expr) -> String {
     }
 }
 
-fn fn_arg_type(
-    ctx: &AssistContext,
-    target_module: hir::Module,
-    fn_arg: &ast::Expr,
-) -> Option<String> {
-    let ty = ctx.sema.type_of_expr(fn_arg)?.adjusted();
-    if ty.is_unknown() {
-        return None;
-    }
+fn fn_arg_type(ctx: &AssistContext, target_module: hir::Module, fn_arg: &ast::Expr) -> String {
+    fn maybe_displayed_type(
+        ctx: &AssistContext,
+        target_module: hir::Module,
+        fn_arg: &ast::Expr,
+    ) -> Option<String> {
+        let ty = ctx.sema.type_of_expr(fn_arg)?.adjusted();
+        if ty.is_unknown() {
+            return None;
+        }
 
-    if let Ok(rendered) = ty.display_source_code(ctx.db(), target_module.into()) {
-        Some(rendered)
-    } else {
-        None
+        if ty.is_reference() || ty.is_mutable_reference() {
+            let famous_defs = &FamousDefs(&ctx.sema, ctx.sema.scope(fn_arg.syntax()).krate());
+            convert_reference_type(ty.strip_references(), ctx.db(), famous_defs)
+                .map(|conversion| conversion.convert_type(ctx.db()))
+                .or_else(|| ty.display_source_code(ctx.db(), target_module.into()).ok())
+        } else {
+            ty.display_source_code(ctx.db(), target_module.into()).ok()
+        }
     }
+
+    maybe_displayed_type(ctx, target_module, fn_arg).unwrap_or_else(|| String::from("_"))
 }
 
 /// Returns the position inside the current mod or file
 /// directly after the current block
 /// We want to write the generated function directly after
 /// fns, impls or macro calls, but inside mods
-fn next_space_for_fn_after_call_site(expr: FuncExpr) -> Option<GeneratedFunctionTarget> {
+fn next_space_for_fn_after_call_site(expr: ast::CallableExpr) -> Option<GeneratedFunctionTarget> {
     let mut ancestors = expr.syntax().ancestors().peekable();
     let mut last_ancestor: Option<SyntaxNode> = None;
     while let Some(next_ancestor) = ancestors.next() {
@@ -579,20 +563,14 @@ fn next_space_for_fn_in_module(
 ) -> Option<(FileId, GeneratedFunctionTarget)> {
     let file = module_source.file_id.original_file(db);
     let assist_item = match &module_source.value {
-        hir::ModuleSource::SourceFile(it) => {
-            if let Some(last_item) = it.items().last() {
-                GeneratedFunctionTarget::BehindItem(last_item.syntax().clone())
-            } else {
-                GeneratedFunctionTarget::BehindItem(it.syntax().clone())
-            }
-        }
-        hir::ModuleSource::Module(it) => {
-            if let Some(last_item) = it.item_list().and_then(|it| it.items().last()) {
-                GeneratedFunctionTarget::BehindItem(last_item.syntax().clone())
-            } else {
-                GeneratedFunctionTarget::InEmptyItemList(it.item_list()?.syntax().clone())
-            }
-        }
+        hir::ModuleSource::SourceFile(it) => match it.items().last() {
+            Some(last_item) => GeneratedFunctionTarget::BehindItem(last_item.syntax().clone()),
+            None => GeneratedFunctionTarget::BehindItem(it.syntax().clone()),
+        },
+        hir::ModuleSource::Module(it) => match it.item_list().and_then(|it| it.items().last()) {
+            Some(last_item) => GeneratedFunctionTarget::BehindItem(last_item.syntax().clone()),
+            None => GeneratedFunctionTarget::InEmptyItemList(it.item_list()?.syntax().clone()),
+        },
         hir::ModuleSource::BlockExpr(it) => {
             if let Some(last_item) =
                 it.statements().take_while(|stmt| matches!(stmt, ast::Stmt::Item(_))).last()
@@ -1691,7 +1669,7 @@ fn main() {
     foo(a.0);
 }
 
-fn foo(arg0: ()) ${0:-> _} {
+fn foo(a: ()) ${0:-> _} {
     todo!()
 }
 ",