]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/convert_bool_then.rs
Implement `bool_then_to_if` assist
[rust.git] / crates / ide_assists / src / handlers / convert_bool_then.rs
1 use hir::{known, AsAssocItem, Semantics};
2 use ide_db::{
3     helpers::{for_each_tail_expr, FamousDefs},
4     RootDatabase,
5 };
6 use itertools::Itertools;
7 use syntax::{
8     ast::{self, edit::AstNodeEdit, make, ArgListOwner},
9     ted, AstNode, SyntaxNode,
10 };
11
12 use crate::{
13     utils::{invert_boolean_expression, unwrap_trivial_block},
14     AssistContext, AssistId, AssistKind, Assists,
15 };
16
17 // Assist: convert_if_to_bool_then
18 //
19 // Converts an if expression into a corresponding `bool::then` call.
20 //
21 // ```
22 // # //- minicore: option
23 // fn main() {
24 //     if$0 cond {
25 //         Some(val)
26 //     } else {
27 //         None
28 //     }
29 // }
30 // ```
31 // ->
32 // ```
33 // fn main() {
34 //     cond.then(|| val)
35 // }
36 // ```
37 pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
38     // todo, applies to match as well
39     let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
40     if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
41         return None;
42     }
43
44     let cond = expr.condition().filter(|cond| !cond.is_pattern_cond())?;
45     let cond = cond.expr()?;
46     let then = expr.then_branch()?;
47     let else_ = match expr.else_branch()? {
48         ast::ElseBranch::Block(b) => b,
49         ast::ElseBranch::IfExpr(_) => {
50             cov_mark::hit!(convert_if_to_bool_then_chain);
51             return None;
52         }
53     };
54
55     let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
56
57     let (invert_cond, closure_body) = match (
58         block_is_none_variant(&ctx.sema, &then, none_variant),
59         block_is_none_variant(&ctx.sema, &else_, none_variant),
60     ) {
61         (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
62         (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
63         _ => return None,
64     };
65
66     if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
67         cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
68         return None;
69     }
70
71     let target = expr.syntax().text_range();
72     acc.add(
73         AssistId("convert_if_to_bool_then", AssistKind::RefactorRewrite),
74         "Convert `if` expression to `bool::then` call",
75         target,
76         |builder| {
77             let closure_body = closure_body.clone_for_update();
78             // Rewrite all `Some(e)` in tail position to `e`
79             let mut replacements = Vec::new();
80             for_each_tail_expr(&closure_body, &mut |e| {
81                 let e = match e {
82                     ast::Expr::BreakExpr(e) => e.expr(),
83                     e @ ast::Expr::CallExpr(_) => Some(e.clone()),
84                     _ => None,
85                 };
86                 if let Some(ast::Expr::CallExpr(call)) = e {
87                     if let Some(arg_list) = call.arg_list() {
88                         if let Some(arg) = arg_list.args().next() {
89                             replacements.push((call.syntax().clone(), arg.syntax().clone()));
90                         }
91                     }
92                 }
93             });
94             replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
95             let closure_body = match closure_body {
96                 ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
97                 e => e,
98             };
99
100             let cond = if invert_cond { invert_boolean_expression(&ctx.sema, cond) } else { cond };
101             let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
102             let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
103             builder.replace(target, mcall.to_string());
104         },
105     )
106 }
107
108 // Assist: convert_bool_then_to_if
109 //
110 // Converts a `bool::then` method call to an equivalent if expression.
111 //
112 // ```
113 // # //- minicore: bool_impl
114 // fn main() {
115 //     (0 == 0).then$0(|| val)
116 // }
117 // ```
118 // ->
119 // ```
120 // fn main() {
121 //     if 0 == 0 {
122 //         Some(val)
123 //     } else {
124 //         None
125 //     }
126 // }
127 // ```
128 pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
129     let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
130     let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
131     let receiver = mcall.receiver()?;
132     let closure_body = mcall.arg_list()?.args().exactly_one().ok()?;
133     let closure_body = match closure_body {
134         ast::Expr::ClosureExpr(expr) => expr.body()?,
135         _ => return None,
136     };
137     // Verify this is `bool::then` that is being called.
138     let func = ctx.sema.resolve_method_call(&mcall)?;
139     if func.name(ctx.sema.db).to_string() != "then" {
140         return None;
141     }
142     let assoc = func.as_assoc_item(ctx.sema.db)?;
143     match assoc.container(ctx.sema.db) {
144         hir::AssocItemContainer::Impl(impl_) if impl_.self_ty(ctx.sema.db).is_bool() => {}
145         _ => return None,
146     }
147
148     let target = mcall.syntax().text_range();
149     acc.add(
150         AssistId("convert_bool_then_to_if", AssistKind::RefactorRewrite),
151         "Convert `bool::then` call to `if`",
152         target,
153         |builder| {
154             let closure_body = match closure_body {
155                 ast::Expr::BlockExpr(block) => block,
156                 e => make::block_expr(None, Some(e)),
157             };
158
159             let closure_body = closure_body.clone_for_update();
160             // Wrap all tails in `Some(...)`
161             let none_path = make::expr_path(make::ext::ident_path("None"));
162             let some_path = make::expr_path(make::ext::ident_path("Some"));
163             let mut replacements = Vec::new();
164             for_each_tail_expr(&ast::Expr::BlockExpr(closure_body.clone()), &mut |e| {
165                 let e = match e {
166                     ast::Expr::BreakExpr(e) => e.expr(),
167                     ast::Expr::ReturnExpr(e) => e.expr(),
168                     _ => Some(e.clone()),
169                 };
170                 if let Some(expr) = e {
171                     replacements.push((
172                         expr.syntax().clone(),
173                         make::expr_call(some_path.clone(), make::arg_list(Some(expr)))
174                             .syntax()
175                             .clone_for_update(),
176                     ));
177                 }
178             });
179             replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
180
181             let cond = match &receiver {
182                 ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
183                 _ => receiver,
184             };
185             let if_expr = make::expr_if(
186                 make::condition(cond, None),
187                 closure_body.reset_indent(),
188                 Some(ast::ElseBranch::Block(make::block_expr(None, Some(none_path)))),
189             )
190             .indent(mcall.indent_level());
191
192             builder.replace(target, if_expr.to_string());
193         },
194     )
195 }
196
197 fn option_variants(
198     sema: &Semantics<RootDatabase>,
199     expr: &SyntaxNode,
200 ) -> Option<(hir::Variant, hir::Variant)> {
201     let fam = FamousDefs(&sema, sema.scope(expr).krate());
202     let option_variants = fam.core_option_Option()?.variants(sema.db);
203     match &*option_variants {
204         &[variant0, variant1] => Some(if variant0.name(sema.db) == known::None {
205             (variant0, variant1)
206         } else {
207             (variant1, variant0)
208         }),
209         _ => None,
210     }
211 }
212
213 /// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression.
214 /// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call.
215 fn is_invalid_body(
216     sema: &Semantics<RootDatabase>,
217     some_variant: hir::Variant,
218     expr: &ast::Expr,
219 ) -> bool {
220     let mut invalid = false;
221     expr.preorder(&mut |e| {
222         invalid |=
223             matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
224         invalid
225     });
226     if !invalid {
227         for_each_tail_expr(&expr, &mut |e| {
228             if invalid {
229                 return;
230             }
231             let e = match e {
232                 ast::Expr::BreakExpr(e) => e.expr(),
233                 e @ ast::Expr::CallExpr(_) => Some(e.clone()),
234                 _ => None,
235             };
236             if let Some(ast::Expr::CallExpr(call)) = e {
237                 if let Some(ast::Expr::PathExpr(p)) = call.expr() {
238                     let res = p.path().and_then(|p| sema.resolve_path(&p));
239                     if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res {
240                         return invalid |= v != some_variant;
241                     }
242                 }
243             }
244             invalid = true
245         });
246     }
247     invalid
248 }
249
250 fn block_is_none_variant(
251     sema: &Semantics<RootDatabase>,
252     block: &ast::BlockExpr,
253     none_variant: hir::Variant,
254 ) -> bool {
255     block.as_lone_tail().and_then(|e| match e {
256         ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
257             hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
258             _ => None,
259         },
260         _ => None,
261     }) == Some(none_variant)
262 }
263
264 #[cfg(test)]
265 mod tests {
266     use crate::tests::{check_assist, check_assist_not_applicable};
267
268     use super::*;
269
270     #[test]
271     fn convert_if_to_bool_then_simple() {
272         check_assist(
273             convert_if_to_bool_then,
274             r"
275 //- minicore:option
276 fn main() {
277     if$0 true {
278         Some(15)
279     } else {
280         None
281     }
282 }
283 ",
284             r"
285 fn main() {
286     true.then(|| 15)
287 }
288 ",
289         );
290     }
291
292     #[test]
293     fn convert_if_to_bool_then_invert() {
294         check_assist(
295             convert_if_to_bool_then,
296             r"
297 //- minicore:option
298 fn main() {
299     if$0 true {
300         None
301     } else {
302         Some(15)
303     }
304 }
305 ",
306             r"
307 fn main() {
308     false.then(|| 15)
309 }
310 ",
311         );
312     }
313
314     #[test]
315     fn convert_if_to_bool_then_none_none() {
316         check_assist_not_applicable(
317             convert_if_to_bool_then,
318             r"
319 //- minicore:option
320 fn main() {
321     if$0 true {
322         None
323     } else {
324         None
325     }
326 }
327 ",
328         );
329     }
330
331     #[test]
332     fn convert_if_to_bool_then_some_some() {
333         check_assist_not_applicable(
334             convert_if_to_bool_then,
335             r"
336 //- minicore:option
337 fn main() {
338     if$0 true {
339         Some(15)
340     } else {
341         Some(15)
342     }
343 }
344 ",
345         );
346     }
347
348     #[test]
349     fn convert_if_to_bool_then_mixed() {
350         check_assist_not_applicable(
351             convert_if_to_bool_then,
352             r"
353 //- minicore:option
354 fn main() {
355     if$0 true {
356         if true {
357             Some(15)
358         } else {
359             None
360         }
361     } else {
362         None
363     }
364 }
365 ",
366         );
367     }
368
369     #[test]
370     fn convert_if_to_bool_then_chain() {
371         cov_mark::check!(convert_if_to_bool_then_chain);
372         check_assist_not_applicable(
373             convert_if_to_bool_then,
374             r"
375 //- minicore:option
376 fn main() {
377     if$0 true {
378         Some(15)
379     } else if true {
380         None
381     } else {
382         None
383     }
384 }
385 ",
386         );
387     }
388
389     #[test]
390     fn convert_if_to_bool_then_pattern_cond() {
391         check_assist_not_applicable(
392             convert_if_to_bool_then,
393             r"
394 //- minicore:option
395 fn main() {
396     if$0 let true = true {
397         Some(15)
398     } else {
399         None
400     }
401 }
402 ",
403         );
404     }
405
406     #[test]
407     fn convert_if_to_bool_then_pattern_invalid_body() {
408         cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
409         check_assist_not_applicable(
410             convert_if_to_bool_then,
411             r"
412 //- minicore:option
413 fn make_me_an_option() -> Option<i32> { None }
414 fn main() {
415     if$0 true {
416         if true {
417             make_me_an_option()
418         } else {
419             Some(15)
420         }
421     } else {
422         None
423     }
424 }
425 ",
426         );
427         check_assist_not_applicable(
428             convert_if_to_bool_then,
429             r"
430 //- minicore:option
431 fn main() {
432     if$0 true {
433         if true {
434             return;
435         }
436         Some(15)
437     } else {
438         None
439     }
440 }
441 ",
442         );
443     }
444
445     #[test]
446     fn convert_bool_then_to_if_inapplicable() {
447         check_assist_not_applicable(
448             convert_bool_then_to_if,
449             r"
450 //- minicore:bool_impl
451 fn main() {
452     0.t$0hen(|| 15);
453 }
454 ",
455         );
456         check_assist_not_applicable(
457             convert_bool_then_to_if,
458             r"
459 //- minicore:bool_impl
460 fn main() {
461     true.t$0hen(15);
462 }
463 ",
464         );
465         check_assist_not_applicable(
466             convert_bool_then_to_if,
467             r"
468 //- minicore:bool_impl
469 fn main() {
470     true.t$0hen(|| 15, 15);
471 }
472 ",
473         );
474     }
475
476     #[test]
477     fn convert_bool_then_to_if_simple() {
478         check_assist(
479             convert_bool_then_to_if,
480             r"
481 //- minicore:bool_impl
482 fn main() {
483     true.t$0hen(|| 15)
484 }
485 ",
486             r"
487 fn main() {
488     if true {
489         Some(15)
490     } else {
491         None
492     }
493 }
494 ",
495         );
496         check_assist(
497             convert_bool_then_to_if,
498             r"
499 //- minicore:bool_impl
500 fn main() {
501     true.t$0hen(|| {
502         15
503     })
504 }
505 ",
506             r"
507 fn main() {
508     if true {
509         Some(15)
510     } else {
511         None
512     }
513 }
514 ",
515         );
516     }
517
518     #[test]
519     fn convert_bool_then_to_if_tails() {
520         check_assist(
521             convert_bool_then_to_if,
522             r"
523 //- minicore:bool_impl
524 fn main() {
525     true.t$0hen(|| {
526         loop {
527             if false {
528                 break 0;
529             }
530             break 15;
531         }
532     })
533 }
534 ",
535             r"
536 fn main() {
537     if true {
538         loop {
539             if false {
540                 break Some(0);
541             }
542             break Some(15);
543         }
544     } else {
545         None
546     }
547 }
548 ",
549         );
550     }
551 }