1 use ide_db::defs::{Definition, NameRefClass};
4 ted, AstNode, SyntaxNode,
8 assist_context::{AssistContext, Assists},
12 // Assist: convert_match_to_let_else
14 // Converts a let statement with a match initializer to a let-else statement.
17 // # //- minicore: option
18 // fn foo(opt: Option<()>) {
19 // let val = $0match opt {
27 // fn foo(opt: Option<()>) {
28 // let Some(val) = opt else { return };
31 pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
32 let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?;
33 let binding = find_binding(let_stmt.pat()?)?;
35 let initializer = match let_stmt.initializer() {
36 Some(ast::Expr::MatchExpr(it)) => it,
39 let initializer_expr = initializer.expr()?;
41 let (extracting_arm, diverging_arm) = match find_arms(ctx, &initializer) {
45 if extracting_arm.guard().is_some() {
46 cov_mark::hit!(extracting_arm_has_guard);
50 let diverging_arm_expr = diverging_arm.expr()?;
51 let extracting_arm_pat = extracting_arm.pat()?;
52 let extracted_variable = find_extracted_variable(ctx, &extracting_arm)?;
55 AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
56 "Convert match to let-else",
57 let_stmt.syntax().text_range(),
59 let extracting_arm_pat = rename_variable(&extracting_arm_pat, extracted_variable, binding);
61 let_stmt.syntax().text_range(),
62 format!("let {extracting_arm_pat} = {initializer_expr} else {{ {diverging_arm_expr} }};")
68 // Given a pattern, find the name introduced to the surrounding scope.
69 fn find_binding(pat: ast::Pat) -> Option<ast::IdentPat> {
70 if let ast::Pat::IdentPat(ident) = pat {
77 // Given a match expression, find extracting and diverging arms.
79 ctx: &AssistContext<'_>,
80 match_expr: &ast::MatchExpr,
81 ) -> Option<(ast::MatchArm, ast::MatchArm)> {
82 let arms = match_expr.match_arm_list()?.arms().collect::<Vec<_>>();
87 let mut extracting = None;
88 let mut diverging = None;
90 if ctx.sema.type_of_expr(&arm.expr().unwrap()).unwrap().original().is_never() {
91 diverging = Some(arm);
93 extracting = Some(arm);
97 match (extracting, diverging) {
98 (Some(extracting), Some(diverging)) => Some((extracting, diverging)),
100 cov_mark::hit!(non_diverging_match);
106 // Given an extracting arm, find the extracted variable.
107 fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<ast::Name> {
109 ast::Expr::PathExpr(path) => {
110 let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
111 match NameRefClass::classify(&ctx.sema, &name_ref)? {
112 NameRefClass::Definition(Definition::Local(local)) => {
113 let source = local.source(ctx.db()).value.left()?;
120 cov_mark::hit!(extracting_arm_is_not_an_identity_expr);
126 // Rename `extracted` to `binding` in `pat`.
127 fn rename_variable(pat: &ast::Pat, extracted: ast::Name, binding: ast::IdentPat) -> SyntaxNode {
128 let syntax = pat.syntax().clone_for_update();
129 let extracted_syntax = syntax.covering_element(extracted.syntax().text_range());
131 // If `extracted` variable is a record field, we should rename it to `binding`,
132 // otherwise we just need to replace `extracted` with `binding`.
134 if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
136 if let Some(name_ref) = record_pat_field.field_name() {
138 record_pat_field.syntax(),
139 ast::make::record_pat_field(ast::make::name_ref(&name_ref.text()), binding.into())
145 ted::replace(extracted_syntax, binding.syntax().clone_for_update());
153 use crate::tests::{check_assist, check_assist_not_applicable};
158 fn should_not_be_applicable_for_non_diverging_match() {
159 cov_mark::check!(non_diverging_match);
160 check_assist_not_applicable(
161 convert_match_to_let_else,
164 fn foo(opt: Option<()>) {
165 let val = $0match opt {
175 fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
176 cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
177 check_assist_not_applicable(
178 convert_match_to_let_else,
181 fn foo(opt: Option<i32>) {
182 let val = $0match opt {
190 check_assist_not_applicable(
191 convert_match_to_let_else,
194 fn foo(opt: Option<()>) {
195 let val = $0match opt {
208 fn should_not_be_applicable_if_extracting_arm_has_guard() {
209 cov_mark::check!(extracting_arm_has_guard);
210 check_assist_not_applicable(
211 convert_match_to_let_else,
214 fn foo(opt: Option<()>) {
215 let val = $0match opt {
216 Some(it) if 2 > 1 => it,
227 convert_match_to_let_else,
230 fn foo(opt: Option<()>) {
231 let val = $0match opt {
238 fn foo(opt: Option<()>) {
239 let Some(val) = opt else { return };
246 fn keeps_modifiers() {
248 convert_match_to_let_else,
251 fn foo(opt: Option<()>) {
252 let ref mut val = $0match opt {
259 fn foo(opt: Option<()>) {
260 let Some(ref mut val) = opt else { return };
267 fn nested_pattern() {
269 convert_match_to_let_else,
271 //- minicore: option, result
272 fn foo(opt: Option<Result<()>>) {
273 let val = $0match opt {
280 fn foo(opt: Option<Result<()>>) {
281 let Some(Ok(val)) = opt else { return };
288 fn works_with_any_diverging_block() {
290 convert_match_to_let_else,
293 fn foo(opt: Option<()>) {
295 let val = $0match opt {
303 fn foo(opt: Option<()>) {
305 let Some(val) = opt else { break };
312 convert_match_to_let_else,
315 fn foo(opt: Option<()>) {
317 let val = $0match opt {
325 fn foo(opt: Option<()>) {
327 let Some(val) = opt else { continue };
334 convert_match_to_let_else,
339 fn foo(opt: Option<()>) {
341 let val = $0match opt {
351 fn foo(opt: Option<()>) {
353 let Some(val) = opt else { panic() };
361 fn struct_pattern() {
363 convert_match_to_let_else,
371 fn foo(opt: Option<Point>) {
372 let val = $0match opt {
373 Some(Point { x: 0, y }) => y,
384 fn foo(opt: Option<Point>) {
385 let Some(Point { x: 0, y: val }) = opt else { return };
392 fn renames_whole_binding() {
394 convert_match_to_let_else,
397 fn foo(opt: Option<i32>) -> Option<i32> {
398 let val = $0match opt {
406 fn foo(opt: Option<i32>) -> Option<i32> {
407 let val @ Some(42) = opt else { return None };