use hir::{HirDisplay, Local};
use ide_db::{
defs::{Definition, NameRefClass},
- search::SearchScope,
+ search::{ReferenceAccess, SearchScope},
};
use itertools::Itertools;
use stdx::format_to;
use syntax::{
+ algo::SyntaxRewriter,
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
- AstNode, NameOwner,
+ AstNode,
},
Direction, SyntaxElement,
SyntaxKind::{self, BLOCK_EXPR, BREAK_EXPR, COMMENT, PATH_EXPR, RETURN_EXPR},
- SyntaxNode, TextRange,
+ SyntaxNode, TextRange, T,
};
use test_utils::mark;
let mut self_param = None;
let param_pats: Vec<_> = vars_used_in_body
.iter()
- .map(|node| node.source(ctx.db()))
- .filter(|src| {
+ .map(|node| (node, node.source(ctx.db())))
+ .filter(|(_, src)| {
src.file_id.original_file(ctx.db()) == ctx.frange.file_id
&& !body.contains_node(&either_syntax(&src.value))
})
- .filter_map(|src| match src.value {
- Either::Left(pat) => Some(pat),
+ .filter_map(|(&node, src)| match src.value {
+ Either::Left(_) => Some(node),
Either::Right(it) => {
// we filter self param, as there can only be one
- self_param = Some(it);
+ self_param = Some((node, it));
None
}
})
let vars_defined_in_body = vars_defined_in_body(&body, ctx);
- let vars_in_body_used_afterwards: Vec<_> = vars_defined_in_body
+ let vars_defined_in_body_and_outlive: Vec<_> = vars_defined_in_body
.iter()
.copied()
.filter(|node| {
})
.collect();
- let params = param_pats
+ let params: Vec<_> = param_pats
.into_iter()
- .map(|pat| {
- let name = pat.name().unwrap().to_string();
-
- let ty = ctx
- .sema
- .type_of_pat(&pat.into())
- .and_then(|ty| ty.display_source_code(ctx.db(), module.into()).ok())
- .unwrap_or_else(|| "()".to_string());
+ .map(|node| {
+ let usages = Definition::Local(node)
+ .usages(&ctx.sema)
+ .in_scope(SearchScope::single_file(ctx.frange.file_id))
+ .all();
- Param { name, ty }
+ let has_usages_afterwards = usages
+ .iter()
+ .flat_map(|(_, rs)| rs.iter())
+ .any(|reference| body.preceedes_range(reference.range));
+ let has_mut_inside_body = usages
+ .iter()
+ .flat_map(|(_, rs)| rs.iter())
+ .filter(|reference| body.contains_range(reference.range))
+ .any(|reference| reference.access == Some(ReferenceAccess::Write));
+
+ Param { node, has_usages_afterwards, has_mut_inside_body, is_copy: true }
})
- .collect::<Vec<_>>();
+ .collect();
let expr = body.tail_expr();
let ret_ty = match expr {
};
let has_unit_ret = ret_ty.as_ref().map_or(true, |it| it.is_unit());
- if stdx::never!(!vars_in_body_used_afterwards.is_empty() && !has_unit_ret) {
+ if stdx::never!(!vars_defined_in_body_and_outlive.is_empty() && !has_unit_ret) {
// We should not have variables that outlive body if we have expression block
return None;
}
move |builder| {
let fun = Function {
name: "fun_name".to_string(),
- self_param,
+ self_param: self_param.map(|(_, pat)| pat),
params,
ret_ty,
body,
- vars_in_body_used_afterwards,
+ vars_defined_in_body_and_outlive,
};
builder.replace(target_range, format_replacement(ctx, &fun));
fn format_replacement(ctx: &AssistContext, fun: &Function) -> String {
let mut buf = String::new();
- match fun.vars_in_body_used_afterwards.len() {
- 0 => {}
- 1 => format_to!(
- buf,
- "let {} = ",
- fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap()
- ),
- _ => {
+ match fun.vars_defined_in_body_and_outlive.as_slice() {
+ [] => {}
+ [var] => format_to!(buf, "let {} = ", var.name(ctx.db()).unwrap()),
+ [v0, vs @ ..] => {
buf.push_str("let (");
- format_to!(buf, "{}", fun.vars_in_body_used_afterwards[0].name(ctx.db()).unwrap());
- for local in fun.vars_in_body_used_afterwards.iter().skip(1) {
+ format_to!(buf, "{}", v0.name(ctx.db()).unwrap());
+ for local in vs {
format_to!(buf, ", {}", local.name(ctx.db()).unwrap());
}
buf.push_str(") = ");
{
let mut it = fun.params.iter();
if let Some(param) = it.next() {
- format_to!(buf, "{}", param.name);
+ format_to!(buf, "{}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap());
}
for param in it {
- format_to!(buf, ", {}", param.name);
+ format_to!(buf, ", {}{}", param.value_prefix(), param.node.name(ctx.db()).unwrap());
}
}
format_to!(buf, ")");
params: Vec<Param>,
ret_ty: Option<hir::Type>,
body: FunctionBody,
- vars_in_body_used_afterwards: Vec<Local>,
+ vars_defined_in_body_and_outlive: Vec<Local>,
}
impl Function {
#[derive(Debug)]
struct Param {
- name: String,
- ty: String,
+ node: Local,
+ has_usages_afterwards: bool,
+ has_mut_inside_body: bool,
+ is_copy: bool,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum ParamKind {
+ Value,
+ MutValue,
+ SharedRef,
+ MutRef,
+}
+
+impl ParamKind {
+ fn is_ref(&self) -> bool {
+ matches!(self, ParamKind::SharedRef | ParamKind::MutRef)
+ }
+}
+
+impl Param {
+ fn kind(&self) -> ParamKind {
+ match (self.has_usages_afterwards, self.has_mut_inside_body, self.is_copy) {
+ (true, true, _) => ParamKind::MutRef,
+ (true, false, false) => ParamKind::SharedRef,
+ (false, true, _) => ParamKind::MutValue,
+ (true, false, true) | (false, false, _) => ParamKind::Value,
+ }
+ }
+
+ fn value_prefix(&self) -> &'static str {
+ match self.kind() {
+ ParamKind::Value => "",
+ ParamKind::MutValue => "",
+ ParamKind::SharedRef => "&",
+ ParamKind::MutRef => "&mut ",
+ }
+ }
+
+ fn type_prefix(&self) -> &'static str {
+ match self.kind() {
+ ParamKind::Value => "",
+ ParamKind::MutValue => "",
+ ParamKind::SharedRef => "&",
+ ParamKind::MutRef => "&mut ",
+ }
+ }
+
+ fn mut_pattern(&self) -> &'static str {
+ match self.kind() {
+ ParamKind::MutValue => "mut ",
+ _ => "",
+ }
+ }
}
fn format_function(
if let Some(self_param) = &fun.self_param {
format_to!(fn_def, "{}", self_param);
} else if let Some(param) = it.next() {
- format_to!(fn_def, "{}: {}", param.name, param.ty);
+ format_to!(
+ fn_def,
+ "{}{}: {}{}",
+ param.mut_pattern(),
+ param.node.name(ctx.db()).unwrap(),
+ param.type_prefix(),
+ format_type(¶m.node.ty(ctx.db()), ctx, module)
+ );
}
for param in it {
- format_to!(fn_def, ", {}: {}", param.name, param.ty);
+ format_to!(
+ fn_def,
+ ", {}{}: {}{}",
+ param.mut_pattern(),
+ param.node.name(ctx.db()).unwrap(),
+ param.type_prefix(),
+ format_type(¶m.node.ty(ctx.db()), ctx, module)
+ );
}
}
format_to!(fn_def, " -> {}", format_type(ty, ctx, module));
}
} else {
- match fun.vars_in_body_used_afterwards.as_slice() {
+ match fun.vars_defined_in_body_and_outlive.as_slice() {
[] => {}
[var] => {
format_to!(fn_def, " -> {}", format_type(&var.ty(ctx.db()), ctx, module));
FunctionBody::Expr(expr) => {
fn_def.push('\n');
let expr = expr.indent(indent);
- format_to!(fn_def, "{}{}", indent + 1, expr.syntax());
+ let expr = fix_param_usages(ctx, &fun.params, expr.syntax());
+ format_to!(fn_def, "{}{}", indent + 1, expr);
fn_def.push('\n');
}
FunctionBody::Span { elements, leading_indent } => {
format_to!(fn_def, "{}", leading_indent);
- for e in elements {
- format_to!(fn_def, "{}", e);
+ for element in elements {
+ match element {
+ syntax::NodeOrToken::Node(node) => {
+ format_to!(fn_def, "{}", fix_param_usages(ctx, &fun.params, node));
+ }
+ syntax::NodeOrToken::Token(token) => {
+ format_to!(fn_def, "{}", token);
+ }
+ }
}
if !fn_def.ends_with('\n') {
fn_def.push('\n');
}
}
- match fun.vars_in_body_used_afterwards.as_slice() {
+ match fun.vars_defined_in_body_and_outlive.as_slice() {
[] => {}
[var] => format_to!(fn_def, "{}{}\n", indent + 1, var.name(ctx.db()).unwrap()),
[v0, vs @ ..] => {
ty.display_source_code(ctx.db(), module.into()).ok().unwrap_or_else(|| "()".to_string())
}
+fn fix_param_usages(ctx: &AssistContext, params: &[Param], syntax: &SyntaxNode) -> SyntaxNode {
+ let mut rewriter = SyntaxRewriter::default();
+ for param in params {
+ if !param.kind().is_ref() {
+ continue;
+ }
+
+ let usages = Definition::Local(param.node)
+ .usages(&ctx.sema)
+ .in_scope(SearchScope::single_file(ctx.frange.file_id))
+ .all();
+ let usages = usages
+ .iter()
+ .flat_map(|(_, rs)| rs.iter())
+ .filter(|reference| syntax.text_range().contains_range(reference.range));
+ for reference in usages {
+ let token = match syntax.token_at_offset(reference.range.start()).right_biased() {
+ Some(a) => a,
+ None => {
+ stdx::never!(false, "cannot find token at variable usage: {:?}", reference);
+ continue;
+ }
+ };
+ let path = match token.ancestors().find_map(ast::Expr::cast) {
+ Some(n) => n,
+ None => {
+ stdx::never!(false, "cannot find path parent of variable usage: {:?}", token);
+ continue;
+ }
+ };
+ stdx::always!(matches!(path, ast::Expr::PathExpr(_)));
+ match path.syntax().ancestors().skip(1).find_map(ast::Expr::cast) {
+ Some(ast::Expr::MethodCallExpr(_)) => {
+ // do nothing
+ }
+ Some(ast::Expr::RefExpr(node))
+ if param.kind() == ParamKind::MutRef && node.mut_token().is_some() =>
+ {
+ rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
+ }
+ Some(ast::Expr::RefExpr(node))
+ if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() =>
+ {
+ rewriter.replace_ast(&node.clone().into(), &node.expr().unwrap());
+ }
+ Some(_) | None => {
+ rewriter.replace_ast(&path, &ast::make::expr_prefix(T![*], path.clone()));
+ }
+ };
+ }
+ }
+
+ rewriter.rewrite(syntax)
+}
+
#[derive(Debug)]
enum FunctionBody {
Expr(ast::Expr),
let k = n * n;
let m = k + 2;
(k, m)
+}",
+ );
+ }
+
+ #[test]
+ fn mut_var_from_outer_scope() {
+ check_assist(
+ extract_function,
+ r"
+fn foo() {
+ let mut n = 1;
+ $0n += 1;$0
+ let m = n + 1;
+}",
+ r"
+fn foo() {
+ let mut n = 1;
+ fun_name(&mut n);
+ let m = n + 1;
+}
+
+fn $0fun_name(n: &mut i32) {
+ *n += 1;
+}",
+ );
+ }
+
+ #[test]
+ fn mut_param_many_usages_stmt() {
+ check_assist(
+ extract_function,
+ r"
+fn bar(k: i32) {}
+trait I: Copy {
+ fn succ(&self) -> Self;
+ fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
+}
+impl I for i32 {
+ fn succ(&self) -> Self { *self + 1 }
+}
+fn foo() {
+ let mut n = 1;
+ $0n += n;
+ bar(n);
+ bar(n+1);
+ bar(n*n);
+ bar(&n);
+ n.inc();
+ let v = &mut n;
+ *v = v.succ();
+ n.succ();$0
+ let m = n + 1;
+}",
+ r"
+fn bar(k: i32) {}
+trait I: Copy {
+ fn succ(&self) -> Self;
+ fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
+}
+impl I for i32 {
+ fn succ(&self) -> Self { *self + 1 }
+}
+fn foo() {
+ let mut n = 1;
+ fun_name(&mut n);
+ let m = n + 1;
+}
+
+fn $0fun_name(n: &mut i32) {
+ *n += *n;
+ bar(*n);
+ bar(*n+1);
+ bar(*n**n);
+ bar(&*n);
+ n.inc();
+ let v = n;
+ *v = v.succ();
+ n.succ();
+}",
+ );
+ }
+
+ #[test]
+ fn mut_param_many_usages_expr() {
+ check_assist(
+ extract_function,
+ r"
+fn bar(k: i32) {}
+trait I: Copy {
+ fn succ(&self) -> Self;
+ fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
+}
+impl I for i32 {
+ fn succ(&self) -> Self { *self + 1 }
+}
+fn foo() {
+ let mut n = 1;
+ $0{
+ n += n;
+ bar(n);
+ bar(n+1);
+ bar(n*n);
+ bar(&n);
+ n.inc();
+ let v = &mut n;
+ *v = v.succ();
+ n.succ();
+ }$0
+ let m = n + 1;
+}",
+ r"
+fn bar(k: i32) {}
+trait I: Copy {
+ fn succ(&self) -> Self;
+ fn inc(&mut self) -> Self { let v = self.succ(); *self = v; v }
+}
+impl I for i32 {
+ fn succ(&self) -> Self { *self + 1 }
+}
+fn foo() {
+ let mut n = 1;
+ fun_name(&mut n);
+ let m = n + 1;
+}
+
+fn $0fun_name(n: &mut i32) {
+ {
+ *n += *n;
+ bar(*n);
+ bar(*n+1);
+ bar(*n**n);
+ bar(&*n);
+ n.inc();
+ let v = n;
+ *v = v.succ();
+ n.succ();
+ }
+}",
+ );
+ }
+
+ #[test]
+ fn mut_param_by_value() {
+ check_assist(
+ extract_function,
+ r"
+fn foo() {
+ let mut n = 1;
+ $0n += 1;$0
+}",
+ r"
+fn foo() {
+ let mut n = 1;
+ fun_name(n);
+}
+
+fn $0fun_name(mut n: i32) {
+ n += 1;
}",
);
}