]> git.lizzy.rs Git - rust.git/commitdiff
Use strongly-typed ast building for early-return assist
authorAleksey Kladov <aleksey.kladov@gmail.com>
Wed, 13 Nov 2019 08:40:51 +0000 (11:40 +0300)
committerAleksey Kladov <aleksey.kladov@gmail.com>
Wed, 13 Nov 2019 08:54:21 +0000 (11:54 +0300)
crates/ra_assists/src/assists/early_return.rs
crates/ra_syntax/src/ast/make.rs

index 8507a60fb9a5c629e759596b3f115fffae1bb336..26441252662b97c015386bf040f7ebbf4d102918 100644 (file)
@@ -1,4 +1,4 @@
-use std::ops::RangeInclusive;
+use std::{iter::once, ops::RangeInclusive};
 
 use hir::db::HirDatabase;
 use ra_syntax::{
@@ -45,19 +45,22 @@ pub(crate) fn convert_to_guarded_return(ctx: AssistCtx<impl HirDatabase>) -> Opt
     let cond = if_expr.condition()?;
 
     // Check if there is an IfLet that we can handle.
-    let bound_ident = match cond.pat() {
+    let if_let_pat = match cond.pat() {
         None => None, // No IfLet, supported.
         Some(TupleStructPat(pat)) if pat.args().count() == 1 => {
             let path = pat.path()?;
             match path.qualifier() {
-                None => Some(path.segment()?.name_ref()?),
+                None => {
+                    let bound_ident = pat.args().next().unwrap();
+                    Some((path, bound_ident))
+                }
                 Some(_) => return None,
             }
         }
         Some(_) => return None, // Unsupported IfLet.
     };
 
-    let expr = cond.expr()?;
+    let cond_expr = cond.expr()?;
     let then_block = if_expr.then_branch()?.block()?;
 
     let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::Block::cast)?;
@@ -79,11 +82,11 @@ pub(crate) fn convert_to_guarded_return(ctx: AssistCtx<impl HirDatabase>) -> Opt
 
     let parent_container = parent_block.syntax().parent()?.parent()?;
 
-    let early_expression = match parent_container.kind() {
-        WHILE_EXPR | LOOP_EXPR => Some("continue"),
-        FN_DEF => Some("return"),
-        _ => None,
-    }?;
+    let early_expression: ast::Expr = match parent_container.kind() {
+        WHILE_EXPR | LOOP_EXPR => make::expr_continue().into(),
+        FN_DEF => make::expr_return().into(),
+        _ => return None,
+    };
 
     if then_block.syntax().first_child_or_token().map(|t| t.kind() == L_CURLY).is_none() {
         return None;
@@ -94,22 +97,43 @@ pub(crate) fn convert_to_guarded_return(ctx: AssistCtx<impl HirDatabase>) -> Opt
 
     ctx.add_assist(AssistId("convert_to_guarded_return"), "convert to guarded return", |edit| {
         let if_indent_level = IndentLevel::from_node(&if_expr.syntax());
-        let new_block = match bound_ident {
+        let new_block = match if_let_pat {
             None => {
                 // If.
-                let early_expression = &(early_expression.to_owned() + ";");
-                let new_expr =
-                    if_indent_level.increase_indent(make::if_expression(&expr, early_expression));
+                let early_expression = &(early_expression.syntax().to_string() + ";");
+                let new_expr = if_indent_level
+                    .increase_indent(make::if_expression(&cond_expr, early_expression));
                 replace(new_expr.syntax(), &then_block, &parent_block, &if_expr)
             }
-            Some(bound_ident) => {
+            Some((path, bound_ident)) => {
                 // If-let.
-                let new_expr = if_indent_level.increase_indent(make::let_match_early(
-                    expr,
-                    &bound_ident.syntax().to_string(),
-                    early_expression,
-                ));
-                replace(new_expr.syntax(), &then_block, &parent_block, &if_expr)
+                let match_expr = {
+                    let happy_arm = make::match_arm(
+                        once(
+                            make::tuple_struct_pat(
+                                path,
+                                once(make::bind_pat(make::name("it")).into()),
+                            )
+                            .into(),
+                        ),
+                        make::expr_path(make::path_from_name_ref(make::name_ref("it"))).into(),
+                    );
+
+                    let sad_arm = make::match_arm(
+                        // FIXME: would be cool to use `None` or `Err(_)` if appropriate
+                        once(make::placeholder_pat().into()),
+                        early_expression.into(),
+                    );
+
+                    make::expr_match(cond_expr, make::match_arm_list(vec![happy_arm, sad_arm]))
+                };
+
+                let let_stmt = make::let_stmt(
+                    make::bind_pat(make::name(&bound_ident.syntax().to_string())).into(),
+                    Some(match_expr.into()),
+                );
+                let let_stmt = if_indent_level.increase_indent(let_stmt);
+                replace(let_stmt.syntax(), &then_block, &parent_block, &if_expr)
             }
         };
         edit.target(if_expr.syntax().text_range());
@@ -205,7 +229,7 @@ fn main(n: Option<String>) {
                 bar();
                 le<|>t n = match n {
                     Some(it) => it,
-                    None => return,
+                    _ => return,
                 };
                 foo(n);
 
@@ -216,6 +240,29 @@ fn main(n: Option<String>) {
         );
     }
 
+    #[test]
+    fn convert_if_let_result() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+            fn main() {
+                if<|> let Ok(x) = Err(92) {
+                    foo(x);
+                }
+            }
+            "#,
+            r#"
+            fn main() {
+                le<|>t x = match Err(92) {
+                    Ok(it) => it,
+                    _ => return,
+                };
+                foo(x);
+            }
+            "#,
+        );
+    }
+
     #[test]
     fn convert_let_ok_inside_fn() {
         check_assist(
@@ -236,7 +283,7 @@ fn main(n: Option<String>) {
                 bar();
                 le<|>t n = match n {
                     Ok(it) => it,
-                    None => return,
+                    _ => return,
                 };
                 foo(n);
 
@@ -294,7 +341,7 @@ fn main() {
                 while true {
                     le<|>t n = match n {
                         Some(it) => it,
-                        None => continue,
+                        _ => continue,
                     };
                     foo(n);
                     bar();
@@ -351,7 +398,7 @@ fn main() {
                 loop {
                     le<|>t n = match n {
                         Some(it) => it,
-                        None => continue,
+                        _ => continue,
                     };
                     foo(n);
                     bar();
index 95062ef6c467a92829924e02d6563540a6cb616f..6c903ca641f207df1cf551cb9263e28cd2cbbbb4 100644 (file)
@@ -4,6 +4,10 @@
 
 use crate::{ast, AstNode, SourceFile};
 
+pub fn name(text: &str) -> ast::Name {
+    ast_from_text(&format!("mod {};", text))
+}
+
 pub fn name_ref(text: &str) -> ast::NameRef {
     ast_from_text(&format!("fn f() {{ {}; }}", text))
 }
@@ -43,6 +47,21 @@ pub fn expr_unit() -> ast::Expr {
 pub fn expr_unimplemented() -> ast::Expr {
     expr_from_text("unimplemented!()")
 }
+pub fn expr_path(path: ast::Path) -> ast::Expr {
+    expr_from_text(&path.syntax().to_string())
+}
+pub fn expr_continue() -> ast::Expr {
+    expr_from_text("continue")
+}
+pub fn expr_break() -> ast::Expr {
+    expr_from_text("break")
+}
+pub fn expr_return() -> ast::Expr {
+    expr_from_text("return")
+}
+pub fn expr_match(expr: ast::Expr, match_arm_list: ast::MatchArmList) -> ast::Expr {
+    expr_from_text(&format!("match {} {}", expr.syntax(), match_arm_list.syntax()))
+}
 fn expr_from_text(text: &str) -> ast::Expr {
     ast_from_text(&format!("const C: () = {};", text))
 }
@@ -92,8 +111,8 @@ fn from_text(text: &str) -> ast::PathPat {
     }
 }
 
-pub fn match_arm(pats: impl Iterator<Item = ast::Pat>, expr: ast::Expr) -> ast::MatchArm {
-    let pats_str = pats.map(|p| p.syntax().to_string()).join(" | ");
+pub fn match_arm(pats: impl IntoIterator<Item = ast::Pat>, expr: ast::Expr) -> ast::MatchArm {
+    let pats_str = pats.into_iter().map(|p| p.syntax().to_string()).join(" | ");
     return from_text(&format!("{} => {}", pats_str, expr.syntax()));
 
     fn from_text(text: &str) -> ast::MatchArm {
@@ -101,8 +120,8 @@ fn from_text(text: &str) -> ast::MatchArm {
     }
 }
 
-pub fn match_arm_list(arms: impl Iterator<Item = ast::MatchArm>) -> ast::MatchArmList {
-    let arms_str = arms.map(|arm| format!("\n    {}", arm.syntax())).join(",");
+pub fn match_arm_list(arms: impl IntoIterator<Item = ast::MatchArm>) -> ast::MatchArmList {
+    let arms_str = arms.into_iter().map(|arm| format!("\n    {}", arm.syntax())).join(",");
     return from_text(&format!("{},\n", arms_str));
 
     fn from_text(text: &str) -> ast::MatchArmList {
@@ -110,23 +129,6 @@ fn from_text(text: &str) -> ast::MatchArmList {
     }
 }
 
-pub fn let_match_early(expr: ast::Expr, path: &str, early_expression: &str) -> ast::LetStmt {
-    return from_text(&format!(
-        r#"let {} = match {} {{
-    {}(it) => it,
-    None => {},
-}};"#,
-        expr.syntax().text(),
-        expr.syntax().text(),
-        path,
-        early_expression
-    ));
-
-    fn from_text(text: &str) -> ast::LetStmt {
-        ast_from_text(&format!("fn f() {{ {} }}", text))
-    }
-}
-
 pub fn where_pred(path: ast::Path, bounds: impl Iterator<Item = ast::TypeBound>) -> ast::WherePred {
     let bounds = bounds.map(|b| b.syntax().to_string()).join(" + ");
     return from_text(&format!("{}: {}", path.syntax(), bounds));
@@ -153,6 +155,14 @@ pub fn if_expression(condition: &ast::Expr, statement: &str) -> ast::IfExpr {
     ))
 }
 
+pub fn let_stmt(pattern: ast::Pat, initializer: Option<ast::Expr>) -> ast::LetStmt {
+    let text = match initializer {
+        Some(it) => format!("let {} = {};", pattern.syntax(), it.syntax()),
+        None => format!("let {};", pattern.syntax()),
+    };
+    ast_from_text(&format!("fn f() {{ {} }}", text))
+}
+
 fn ast_from_text<N: AstNode>(text: &str) -> N {
     let parse = SourceFile::parse(text);
     let res = parse.tree().syntax().descendants().find_map(N::cast).unwrap();