]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/generate_function.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / generate_function.rs
index ea3656ed2148e79cb34b147219520f25ad5d40b7..ac33d56858c03e89b1c354d7b870525d4fd6af0f 100644 (file)
@@ -1,11 +1,13 @@
-use hir::{HasSource, HirDisplay, Module, ModuleDef, Semantics, TypeInfo};
+use rustc_hash::{FxHashMap, FxHashSet};
+
+use hir::{HasSource, HirDisplay, Module, Semantics, TypeInfo};
+use ide_db::helpers::FamousDefs;
 use ide_db::{
     base_db::FileId,
     defs::{Definition, NameRefClass},
     helpers::SnippetCap,
     RootDatabase,
 };
-use rustc_hash::{FxHashMap, FxHashSet};
 use stdx::to_lower_snake_case;
 use syntax::{
     ast::{
@@ -17,7 +19,7 @@
 };
 
 use crate::{
-    utils::useless_type_special_case,
+    utils::convert_reference_type,
     utils::{find_struct_impl, render_snippet, Cursor},
     AssistContext, AssistId, AssistKind, Assists,
 };
@@ -51,27 +53,6 @@ pub(crate) fn generate_function(acc: &mut Assists, ctx: &AssistContext) -> Optio
     gen_fn(acc, ctx).or_else(|| gen_method(acc, ctx))
 }
 
-enum FuncExpr {
-    Func(ast::CallExpr),
-    Method(ast::MethodCallExpr),
-}
-
-impl FuncExpr {
-    fn arg_list(&self) -> Option<ast::ArgList> {
-        match self {
-            FuncExpr::Func(fn_call) => fn_call.arg_list(),
-            FuncExpr::Method(m_call) => m_call.arg_list(),
-        }
-    }
-
-    fn syntax(&self) -> &SyntaxNode {
-        match self {
-            FuncExpr::Func(fn_call) => fn_call.syntax(),
-            FuncExpr::Method(m_call) => m_call.syntax(),
-        }
-    }
-}
-
 fn gen_fn(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
     let path_expr: ast::PathExpr = ctx.find_node_at_offset()?;
     let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?;
@@ -130,6 +111,10 @@ fn gen_fn(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
 
 fn gen_method(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
     let call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
+    if ctx.sema.resolve_method_call(&call).is_some() {
+        return None;
+    }
+
     let fn_name = call.name_ref()?;
     let adt = ctx.sema.type_of_expr(&call.receiver()?)?.original().strip_references().as_adt()?;
 
@@ -253,7 +238,8 @@ fn from_call(
         let needs_pub = target_module.is_some();
         let target_module = target_module.or_else(|| current_module(target.syntax(), ctx))?;
         let fn_name = make::name(fn_name);
-        let (type_params, params) = fn_args(ctx, target_module, FuncExpr::Func(call.clone()))?;
+        let (type_params, params) =
+            fn_args(ctx, target_module, ast::CallableExpr::Call(call.clone()))?;
 
         let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
         let is_async = await_expr.is_some();
@@ -283,7 +269,8 @@ fn from_method_call(
         let needs_pub =
             !module_is_descendant(&current_module(call.syntax(), ctx)?, &target_module, ctx);
         let fn_name = make::name(&name.text());
-        let (type_params, params) = fn_args(ctx, target_module, FuncExpr::Method(call.clone()))?;
+        let (type_params, params) =
+            fn_args(ctx, target_module, ast::CallableExpr::MethodCall(call.clone()))?;
 
         let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
         let is_async = await_expr.is_some();
@@ -383,7 +370,7 @@ fn get_fn_target(
     target_module: &Option<Module>,
     call: CallExpr,
 ) -> Option<(GeneratedFunctionTarget, FileId, TextSize)> {
-    let mut file = ctx.frange.file_id;
+    let mut file = ctx.file_id();
     let target = match target_module {
         Some(target_module) => {
             let module_source = target_module.definition_source(ctx.db());
@@ -391,7 +378,7 @@ fn get_fn_target(
             file = in_file;
             target
         }
-        None => next_space_for_fn_after_call_site(FuncExpr::Func(call))?,
+        None => next_space_for_fn_after_call_site(ast::CallableExpr::Call(call))?,
     };
     Some((target.clone(), file, get_insert_offset(&target)))
 }
@@ -437,25 +424,13 @@ fn syntax(&self) -> &SyntaxNode {
 fn fn_args(
     ctx: &AssistContext,
     target_module: hir::Module,
-    call: FuncExpr,
+    call: ast::CallableExpr,
 ) -> Option<(Option<ast::GenericParamList>, ast::ParamList)> {
     let mut arg_names = Vec::new();
     let mut arg_types = Vec::new();
     for arg in call.arg_list()?.args() {
         arg_names.push(fn_arg_name(&ctx.sema, &arg));
-        arg_types.push(match fn_arg_type(ctx, target_module, &arg) {
-            Some(ty) => {
-                if !ty.is_empty() && ty.starts_with('&') {
-                    match useless_type_special_case("", &ty[1..].to_owned()) {
-                        Some((new_ty, _)) => new_ty,
-                        None => ty,
-                    }
-                } else {
-                    ty
-                }
-            }
-            None => String::from("_"),
-        });
+        arg_types.push(fn_arg_type(ctx, target_module, &arg));
     }
     deduplicate_arg_names(&mut arg_names);
     let params = arg_names.into_iter().zip(arg_types).map(|(name, ty)| {
@@ -466,8 +441,8 @@ fn fn_args(
         None,
         make::param_list(
             match call {
-                FuncExpr::Func(_) => None,
-                FuncExpr::Method(_) => Some(make::self_param()),
+                ast::CallableExpr::Call(_) => None,
+                ast::CallableExpr::MethodCall(_) => Some(make::self_param()),
             },
             params,
         ),
@@ -485,10 +460,10 @@ fn fn_args(
 /// assert_eq!(names, expected);
 /// ```
 fn deduplicate_arg_names(arg_names: &mut Vec<String>) {
-    let arg_name_counts = arg_names.iter().fold(FxHashMap::default(), |mut m, name| {
-        *m.entry(name).or_insert(0) += 1;
-        m
-    });
+    let mut arg_name_counts = FxHashMap::default();
+    for name in arg_names.iter() {
+        *arg_name_counts.entry(name).or_insert(0) += 1;
+    }
     let duplicate_arg_names: FxHashSet<String> = arg_name_counts
         .into_iter()
         .filter(|(_, count)| *count >= 2)
@@ -510,10 +485,14 @@ fn fn_arg_name(sema: &Semantics<RootDatabase>, arg_expr: &ast::Expr) -> String {
     let name = (|| match arg_expr {
         ast::Expr::CastExpr(cast_expr) => Some(fn_arg_name(sema, &cast_expr.expr()?)),
         expr => {
-            let name_ref = expr.syntax().descendants().filter_map(ast::NameRef::cast).last()?;
-            if let Some(NameRefClass::Definition(Definition::ModuleDef(
-                ModuleDef::Const(_) | ModuleDef::Static(_),
-            ))) = NameRefClass::classify(sema, &name_ref)
+            let name_ref = expr
+                .syntax()
+                .descendants()
+                .filter_map(ast::NameRef::cast)
+                .filter(|name| name.ident_token().is_some())
+                .last()?;
+            if let Some(NameRefClass::Definition(Definition::Const(_) | Definition::Static(_))) =
+                NameRefClass::classify(sema, &name_ref)
             {
                 return Some(name_ref.to_string().to_lowercase());
             };
@@ -530,24 +509,35 @@ fn fn_arg_name(sema: &Semantics<RootDatabase>, arg_expr: &ast::Expr) -> String {
     }
 }
 
-fn fn_arg_type(
-    ctx: &AssistContext,
-    target_module: hir::Module,
-    fn_arg: &ast::Expr,
-) -> Option<String> {
-    let ty = ctx.sema.type_of_expr(fn_arg)?.adjusted();
-    if ty.is_unknown() {
-        return None;
+fn fn_arg_type(ctx: &AssistContext, target_module: hir::Module, fn_arg: &ast::Expr) -> String {
+    fn maybe_displayed_type(
+        ctx: &AssistContext,
+        target_module: hir::Module,
+        fn_arg: &ast::Expr,
+    ) -> Option<String> {
+        let ty = ctx.sema.type_of_expr(fn_arg)?.adjusted();
+        if ty.is_unknown() {
+            return None;
+        }
+
+        if ty.is_reference() || ty.is_mutable_reference() {
+            let famous_defs = &FamousDefs(&ctx.sema, ctx.sema.scope(fn_arg.syntax()).krate());
+            convert_reference_type(ty.strip_references(), ctx.db(), famous_defs)
+                .map(|conversion| conversion.convert_type(ctx.db()))
+                .or_else(|| ty.display_source_code(ctx.db(), target_module.into()).ok())
+        } else {
+            ty.display_source_code(ctx.db(), target_module.into()).ok()
+        }
     }
 
-    ty.display_source_code(ctx.db(), target_module.into()).ok()
+    maybe_displayed_type(ctx, target_module, fn_arg).unwrap_or_else(|| String::from("_"))
 }
 
 /// Returns the position inside the current mod or file
 /// directly after the current block
 /// We want to write the generated function directly after
 /// fns, impls or macro calls, but inside mods
-fn next_space_for_fn_after_call_site(expr: FuncExpr) -> Option<GeneratedFunctionTarget> {
+fn next_space_for_fn_after_call_site(expr: ast::CallableExpr) -> Option<GeneratedFunctionTarget> {
     let mut ancestors = expr.syntax().ancestors().peekable();
     let mut last_ancestor: Option<SyntaxNode> = None;
     while let Some(next_ancestor) = ancestors.next() {
@@ -1679,7 +1669,7 @@ fn main() {
     foo(a.0);
 }
 
-fn foo(arg0: ()) ${0:-> _} {
+fn foo(a: ()) ${0:-> _} {
     todo!()
 }
 ",