// fn foo() -> i32 { 42i32 }
// ```
pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
- let (tail_expr, builder_edit_pos, wrap_expr) = extract_tail(ctx)?;
+ let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?;
let module = ctx.sema.scope(tail_expr.syntax()).module()?;
let ty = ctx.sema.type_of_expr(&tail_expr)?;
if ty.is_unit() {
acc.add(
AssistId("infer_function_return_type", AssistKind::RefactorRewrite),
- "Add this function's return type",
+ match fn_type {
+ FnType::Function => "Add this function's return type",
+ FnType::Closure { .. } => "Add this closure's return type",
+ },
tail_expr.syntax().text_range(),
|builder| {
match builder_edit_pos {
builder.replace(text_range, &format!("-> {}", ty))
}
}
- if wrap_expr {
+ if let FnType::Closure { wrap_expr: true } = fn_type {
mark::hit!(wrap_closure_non_block_expr);
// `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr));
}
}
-fn extract_tail(ctx: &AssistContext) -> Option<(ast::Expr, InsertOrReplace, bool)> {
- let (tail_expr, return_type_range, action, wrap_expr) =
+enum FnType {
+ Function,
+ Closure { wrap_expr: bool },
+}
+
+fn extract_tail(ctx: &AssistContext) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
+ let (fn_type, tail_expr, return_type_range, action) =
if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end();
let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?;
};
let ret_range = TextRange::new(rpipe_pos, body_start);
- (tail_expr, ret_range, action, wrap_expr)
+ (FnType::Closure { wrap_expr }, tail_expr, ret_range, action)
} else {
let func = ctx.find_node_at_offset::<ast::Fn>()?;
let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end();
let ret_range_end = body.l_curly_token()?.text_range().start();
let ret_range = TextRange::new(rparen_pos, ret_range_end);
- (tail_expr, ret_range, action, false)
+ (FnType::Function, tail_expr, ret_range, action)
};
let frange = ctx.frange.range;
if return_type_range.contains_range(frange) {
} else {
return None;
}
- Some((tail_expr, action, wrap_expr))
+ Some((fn_type, tail_expr, action))
}
#[cfg(test)]