]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/convert_bool_then.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / convert_bool_then.rs
index 497e7c2546d0d0e0436d2fb1c7f579dbbc421e0e..274718e6ea90e900466640d40db4d8cdc2415af7 100644 (file)
@@ -1,11 +1,15 @@
 use hir::{known, AsAssocItem, Semantics};
 use ide_db::{
-    helpers::{for_each_tail_expr, FamousDefs},
+    helpers::{
+        for_each_tail_expr,
+        node_ext::{block_as_lone_tail, is_pattern_cond, preorder_expr},
+        FamousDefs,
+    },
     RootDatabase,
 };
 use itertools::Itertools;
 use syntax::{
-    ast::{self, edit::AstNodeEdit, make, ArgListOwner},
+    ast::{self, edit::AstNodeEdit, make, HasArgList},
     ted, AstNode, SyntaxNode,
 };
 
 // }
 // ```
 pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
-    // todo, applies to match as well
+    // FIXME applies to match as well
     let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
     if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
         return None;
     }
 
-    let cond = expr.condition().filter(|cond| !cond.is_pattern_cond())?;
-    let cond = cond.expr()?;
+    let cond = expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?;
     let then = expr.then_branch()?;
     let else_ = match expr.else_branch()? {
         ast::ElseBranch::Block(b) => b,
@@ -97,7 +100,29 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) ->
                 e => e,
             };
 
+            let parenthesize = matches!(
+                cond,
+                ast::Expr::BinExpr(_)
+                    | ast::Expr::BlockExpr(_)
+                    | ast::Expr::BoxExpr(_)
+                    | ast::Expr::BreakExpr(_)
+                    | ast::Expr::CastExpr(_)
+                    | ast::Expr::ClosureExpr(_)
+                    | ast::Expr::ContinueExpr(_)
+                    | ast::Expr::ForExpr(_)
+                    | ast::Expr::IfExpr(_)
+                    | ast::Expr::LoopExpr(_)
+                    | ast::Expr::MacroCall(_)
+                    | ast::Expr::MatchExpr(_)
+                    | ast::Expr::PrefixExpr(_)
+                    | ast::Expr::RangeExpr(_)
+                    | ast::Expr::RefExpr(_)
+                    | ast::Expr::ReturnExpr(_)
+                    | ast::Expr::WhileExpr(_)
+                    | ast::Expr::YieldExpr(_)
+            );
             let cond = if invert_cond { invert_boolean_expression(cond) } else { cond };
+            let cond = if parenthesize { make::expr_paren(cond) } else { cond };
             let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
             let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
             builder.replace(target, mcall.to_string());
@@ -183,7 +208,7 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext) ->
                 _ => receiver,
             };
             let if_expr = make::expr_if(
-                make::condition(cond, None),
+                cond,
                 closure_body.reset_indent(),
                 Some(ast::ElseBranch::Block(make::block_expr(None, Some(none_path)))),
             )
@@ -218,7 +243,7 @@ fn is_invalid_body(
     expr: &ast::Expr,
 ) -> bool {
     let mut invalid = false;
-    expr.preorder(&mut |e| {
+    preorder_expr(expr, &mut |e| {
         invalid |=
             matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
         invalid
@@ -252,7 +277,7 @@ fn block_is_none_variant(
     block: &ast::BlockExpr,
     none_variant: hir::Variant,
 ) -> bool {
-    block.as_lone_tail().and_then(|e| match e {
+    block_as_lone_tail(block).and_then(|e| match e {
         ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
             hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
             _ => None,