]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/convert_bool_then.rs
Merge #9814
[rust.git] / crates / ide_assists / src / handlers / convert_bool_then.rs
1 use hir::{known, Semantics};
2 use ide_db::{
3     helpers::{for_each_tail_expr, FamousDefs},
4     RootDatabase,
5 };
6 use syntax::{
7     ast::{self, make, ArgListOwner},
8     ted, AstNode, SyntaxNode,
9 };
10
11 use crate::{
12     utils::{invert_boolean_expression, unwrap_trivial_block},
13     AssistContext, AssistId, AssistKind, Assists,
14 };
15
16 // Assist: convert_if_to_bool_then
17 //
18 // Converts an if expression into a corresponding `bool::then` call.
19 //
20 // ```
21 // # //- minicore: option
22 // fn main() {
23 //     if$0 cond {
24 //         Some(val)
25 //     } else {
26 //         None
27 //     }
28 // }
29 // ```
30 // ->
31 // ```
32 // fn main() {
33 //     cond.then(|| val)
34 // }
35 // ```
36 pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
37     // todo, applies to match as well
38     let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
39     if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
40         return None;
41     }
42
43     let cond = expr.condition().filter(|cond| !cond.is_pattern_cond())?;
44     let cond = cond.expr()?;
45     let then = expr.then_branch()?;
46     let else_ = match expr.else_branch()? {
47         ast::ElseBranch::Block(b) => b,
48         ast::ElseBranch::IfExpr(_) => {
49             cov_mark::hit!(convert_if_to_bool_then_chain);
50             return None;
51         }
52     };
53
54     let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
55
56     let (invert_cond, closure_body) = match (
57         block_is_none_variant(&ctx.sema, &then, none_variant),
58         block_is_none_variant(&ctx.sema, &else_, none_variant),
59     ) {
60         (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
61         (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
62         _ => return None,
63     };
64
65     if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
66         cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
67         return None;
68     }
69
70     let target = expr.syntax().text_range();
71     acc.add(
72         AssistId("convert_if_to_bool_then", AssistKind::RefactorRewrite),
73         "Convert `if` expression to `bool::then` call",
74         target,
75         |builder| {
76             let closure_body = closure_body.clone_for_update();
77             // Rewrite all `Some(e)` in tail position to `e`
78             for_each_tail_expr(&closure_body, &mut |e| {
79                 let e = match e {
80                     ast::Expr::BreakExpr(e) => e.expr(),
81                     e @ ast::Expr::CallExpr(_) => Some(e.clone()),
82                     _ => None,
83                 };
84                 if let Some(ast::Expr::CallExpr(call)) = e {
85                     if let Some(arg_list) = call.arg_list() {
86                         if let Some(arg) = arg_list.args().next() {
87                             ted::replace(call.syntax(), arg.syntax());
88                         }
89                     }
90                 }
91             });
92             let closure_body = match closure_body {
93                 ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
94                 e => e,
95             };
96
97             let cond = if invert_cond { invert_boolean_expression(&ctx.sema, cond) } else { cond };
98             let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
99             let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
100             builder.replace(target, mcall.to_string());
101         },
102     )
103 }
104
105 fn option_variants(
106     sema: &Semantics<RootDatabase>,
107     expr: &SyntaxNode,
108 ) -> Option<(hir::Variant, hir::Variant)> {
109     let fam = FamousDefs(&sema, sema.scope(expr).krate());
110     let option_variants = fam.core_option_Option()?.variants(sema.db);
111     match &*option_variants {
112         &[variant0, variant1] => Some(if variant0.name(sema.db) == known::None {
113             (variant0, variant1)
114         } else {
115             (variant1, variant0)
116         }),
117         _ => None,
118     }
119 }
120
121 /// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression.
122 /// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call.
123 fn is_invalid_body(
124     sema: &Semantics<RootDatabase>,
125     some_variant: hir::Variant,
126     expr: &ast::Expr,
127 ) -> bool {
128     let mut invalid = false;
129     expr.preorder(&mut |e| {
130         invalid |=
131             matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
132         invalid
133     });
134     if !invalid {
135         for_each_tail_expr(&expr, &mut |e| {
136             if invalid {
137                 return;
138             }
139             let e = match e {
140                 ast::Expr::BreakExpr(e) => e.expr(),
141                 e @ ast::Expr::CallExpr(_) => Some(e.clone()),
142                 _ => None,
143             };
144             if let Some(ast::Expr::CallExpr(call)) = e {
145                 if let Some(ast::Expr::PathExpr(p)) = call.expr() {
146                     let res = p.path().and_then(|p| sema.resolve_path(&p));
147                     if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res {
148                         return invalid |= v != some_variant;
149                     }
150                 }
151             }
152             invalid = true
153         });
154     }
155     invalid
156 }
157
158 fn block_is_none_variant(
159     sema: &Semantics<RootDatabase>,
160     block: &ast::BlockExpr,
161     none_variant: hir::Variant,
162 ) -> bool {
163     block.as_lone_tail().and_then(|e| match e {
164         ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
165             hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
166             _ => None,
167         },
168         _ => None,
169     }) == Some(none_variant)
170 }
171
172 #[cfg(test)]
173 mod tests {
174     use crate::tests::{check_assist, check_assist_not_applicable};
175
176     use super::*;
177
178     #[test]
179     fn convert_if_to_bool_then_simple() {
180         check_assist(
181             convert_if_to_bool_then,
182             r"
183 //- minicore:option
184 fn main() {
185     if$0 true {
186         Some(15)
187     } else {
188         None
189     }
190 }
191 ",
192             r"
193 fn main() {
194     true.then(|| 15)
195 }
196 ",
197         );
198     }
199
200     #[test]
201     fn convert_if_to_bool_then_invert() {
202         check_assist(
203             convert_if_to_bool_then,
204             r"
205 //- minicore:option
206 fn main() {
207     if$0 true {
208         None
209     } else {
210         Some(15)
211     }
212 }
213 ",
214             r"
215 fn main() {
216     false.then(|| 15)
217 }
218 ",
219         );
220     }
221
222     #[test]
223     fn convert_if_to_bool_then_none_none() {
224         check_assist_not_applicable(
225             convert_if_to_bool_then,
226             r"
227 //- minicore:option
228 fn main() {
229     if$0 true {
230         None
231     } else {
232         None
233     }
234 }
235 ",
236         );
237     }
238
239     #[test]
240     fn convert_if_to_bool_then_some_some() {
241         check_assist_not_applicable(
242             convert_if_to_bool_then,
243             r"
244 //- minicore:option
245 fn main() {
246     if$0 true {
247         Some(15)
248     } else {
249         Some(15)
250     }
251 }
252 ",
253         );
254     }
255
256     #[test]
257     fn convert_if_to_bool_then_mixed() {
258         check_assist_not_applicable(
259             convert_if_to_bool_then,
260             r"
261 //- minicore:option
262 fn main() {
263     if$0 true {
264         if true {
265             Some(15)
266         } else {
267             None
268         }
269     } else {
270         None
271     }
272 }
273 ",
274         );
275     }
276
277     #[test]
278     fn convert_if_to_bool_then_chain() {
279         cov_mark::check!(convert_if_to_bool_then_chain);
280         check_assist_not_applicable(
281             convert_if_to_bool_then,
282             r"
283 //- minicore:option
284 fn main() {
285     if$0 true {
286         Some(15)
287     } else if true {
288         None
289     } else {
290         None
291     }
292 }
293 ",
294         );
295     }
296
297     #[test]
298     fn convert_if_to_bool_then_pattern_cond() {
299         check_assist_not_applicable(
300             convert_if_to_bool_then,
301             r"
302 //- minicore:option
303 fn main() {
304     if$0 let true = true {
305         Some(15)
306     } else {
307         None
308     }
309 }
310 ",
311         );
312     }
313
314     #[test]
315     fn convert_if_to_bool_then_pattern_invalid_body() {
316         cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
317         check_assist_not_applicable(
318             convert_if_to_bool_then,
319             r"
320 //- minicore:option
321 fn make_me_an_option() -> Option<i32> { None }
322 fn main() {
323     if$0 true {
324         if true {
325             make_me_an_option()
326         } else {
327             Some(15)
328         }
329     } else {
330         None
331     }
332 }
333 ",
334         );
335         check_assist_not_applicable(
336             convert_if_to_bool_then,
337             r"
338 //- minicore:option
339 fn main() {
340     if$0 true {
341         if true {
342             return;
343         }
344         Some(15)
345     } else {
346         None
347     }
348 }
349 ",
350         );
351     }
352 }