]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/convert_bool_then.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / convert_bool_then.rs
1 use hir::{known, AsAssocItem, Semantics};
2 use ide_db::{
3     helpers::{
4         for_each_tail_expr,
5         node_ext::{block_as_lone_tail, is_pattern_cond, preorder_expr},
6         FamousDefs,
7     },
8     RootDatabase,
9 };
10 use itertools::Itertools;
11 use syntax::{
12     ast::{self, edit::AstNodeEdit, make, HasArgList},
13     ted, AstNode, SyntaxNode,
14 };
15
16 use crate::{
17     utils::{invert_boolean_expression, unwrap_trivial_block},
18     AssistContext, AssistId, AssistKind, Assists,
19 };
20
21 // Assist: convert_if_to_bool_then
22 //
23 // Converts an if expression into a corresponding `bool::then` call.
24 //
25 // ```
26 // # //- minicore: option
27 // fn main() {
28 //     if$0 cond {
29 //         Some(val)
30 //     } else {
31 //         None
32 //     }
33 // }
34 // ```
35 // ->
36 // ```
37 // fn main() {
38 //     cond.then(|| val)
39 // }
40 // ```
41 pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
42     // FIXME applies to match as well
43     let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
44     if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
45         return None;
46     }
47
48     let cond = expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?;
49     let then = expr.then_branch()?;
50     let else_ = match expr.else_branch()? {
51         ast::ElseBranch::Block(b) => b,
52         ast::ElseBranch::IfExpr(_) => {
53             cov_mark::hit!(convert_if_to_bool_then_chain);
54             return None;
55         }
56     };
57
58     let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
59
60     let (invert_cond, closure_body) = match (
61         block_is_none_variant(&ctx.sema, &then, none_variant),
62         block_is_none_variant(&ctx.sema, &else_, none_variant),
63     ) {
64         (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
65         (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
66         _ => return None,
67     };
68
69     if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
70         cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
71         return None;
72     }
73
74     let target = expr.syntax().text_range();
75     acc.add(
76         AssistId("convert_if_to_bool_then", AssistKind::RefactorRewrite),
77         "Convert `if` expression to `bool::then` call",
78         target,
79         |builder| {
80             let closure_body = closure_body.clone_for_update();
81             // Rewrite all `Some(e)` in tail position to `e`
82             let mut replacements = Vec::new();
83             for_each_tail_expr(&closure_body, &mut |e| {
84                 let e = match e {
85                     ast::Expr::BreakExpr(e) => e.expr(),
86                     e @ ast::Expr::CallExpr(_) => Some(e.clone()),
87                     _ => None,
88                 };
89                 if let Some(ast::Expr::CallExpr(call)) = e {
90                     if let Some(arg_list) = call.arg_list() {
91                         if let Some(arg) = arg_list.args().next() {
92                             replacements.push((call.syntax().clone(), arg.syntax().clone()));
93                         }
94                     }
95                 }
96             });
97             replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
98             let closure_body = match closure_body {
99                 ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
100                 e => e,
101             };
102
103             let parenthesize = matches!(
104                 cond,
105                 ast::Expr::BinExpr(_)
106                     | ast::Expr::BlockExpr(_)
107                     | ast::Expr::BoxExpr(_)
108                     | ast::Expr::BreakExpr(_)
109                     | ast::Expr::CastExpr(_)
110                     | ast::Expr::ClosureExpr(_)
111                     | ast::Expr::ContinueExpr(_)
112                     | ast::Expr::ForExpr(_)
113                     | ast::Expr::IfExpr(_)
114                     | ast::Expr::LoopExpr(_)
115                     | ast::Expr::MacroCall(_)
116                     | ast::Expr::MatchExpr(_)
117                     | ast::Expr::PrefixExpr(_)
118                     | ast::Expr::RangeExpr(_)
119                     | ast::Expr::RefExpr(_)
120                     | ast::Expr::ReturnExpr(_)
121                     | ast::Expr::WhileExpr(_)
122                     | ast::Expr::YieldExpr(_)
123             );
124             let cond = if invert_cond { invert_boolean_expression(cond) } else { cond };
125             let cond = if parenthesize { make::expr_paren(cond) } else { cond };
126             let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
127             let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
128             builder.replace(target, mcall.to_string());
129         },
130     )
131 }
132
133 // Assist: convert_bool_then_to_if
134 //
135 // Converts a `bool::then` method call to an equivalent if expression.
136 //
137 // ```
138 // # //- minicore: bool_impl
139 // fn main() {
140 //     (0 == 0).then$0(|| val)
141 // }
142 // ```
143 // ->
144 // ```
145 // fn main() {
146 //     if 0 == 0 {
147 //         Some(val)
148 //     } else {
149 //         None
150 //     }
151 // }
152 // ```
153 pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
154     let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
155     let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
156     let receiver = mcall.receiver()?;
157     let closure_body = mcall.arg_list()?.args().exactly_one().ok()?;
158     let closure_body = match closure_body {
159         ast::Expr::ClosureExpr(expr) => expr.body()?,
160         _ => return None,
161     };
162     // Verify this is `bool::then` that is being called.
163     let func = ctx.sema.resolve_method_call(&mcall)?;
164     if func.name(ctx.sema.db).to_string() != "then" {
165         return None;
166     }
167     let assoc = func.as_assoc_item(ctx.sema.db)?;
168     match assoc.container(ctx.sema.db) {
169         hir::AssocItemContainer::Impl(impl_) if impl_.self_ty(ctx.sema.db).is_bool() => {}
170         _ => return None,
171     }
172
173     let target = mcall.syntax().text_range();
174     acc.add(
175         AssistId("convert_bool_then_to_if", AssistKind::RefactorRewrite),
176         "Convert `bool::then` call to `if`",
177         target,
178         |builder| {
179             let closure_body = match closure_body {
180                 ast::Expr::BlockExpr(block) => block,
181                 e => make::block_expr(None, Some(e)),
182             };
183
184             let closure_body = closure_body.clone_for_update();
185             // Wrap all tails in `Some(...)`
186             let none_path = make::expr_path(make::ext::ident_path("None"));
187             let some_path = make::expr_path(make::ext::ident_path("Some"));
188             let mut replacements = Vec::new();
189             for_each_tail_expr(&ast::Expr::BlockExpr(closure_body.clone()), &mut |e| {
190                 let e = match e {
191                     ast::Expr::BreakExpr(e) => e.expr(),
192                     ast::Expr::ReturnExpr(e) => e.expr(),
193                     _ => Some(e.clone()),
194                 };
195                 if let Some(expr) = e {
196                     replacements.push((
197                         expr.syntax().clone(),
198                         make::expr_call(some_path.clone(), make::arg_list(Some(expr)))
199                             .syntax()
200                             .clone_for_update(),
201                     ));
202                 }
203             });
204             replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
205
206             let cond = match &receiver {
207                 ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
208                 _ => receiver,
209             };
210             let if_expr = make::expr_if(
211                 cond,
212                 closure_body.reset_indent(),
213                 Some(ast::ElseBranch::Block(make::block_expr(None, Some(none_path)))),
214             )
215             .indent(mcall.indent_level());
216
217             builder.replace(target, if_expr.to_string());
218         },
219     )
220 }
221
222 fn option_variants(
223     sema: &Semantics<RootDatabase>,
224     expr: &SyntaxNode,
225 ) -> Option<(hir::Variant, hir::Variant)> {
226     let fam = FamousDefs(sema, sema.scope(expr).krate());
227     let option_variants = fam.core_option_Option()?.variants(sema.db);
228     match &*option_variants {
229         &[variant0, variant1] => Some(if variant0.name(sema.db) == known::None {
230             (variant0, variant1)
231         } else {
232             (variant1, variant0)
233         }),
234         _ => None,
235     }
236 }
237
238 /// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression.
239 /// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call.
240 fn is_invalid_body(
241     sema: &Semantics<RootDatabase>,
242     some_variant: hir::Variant,
243     expr: &ast::Expr,
244 ) -> bool {
245     let mut invalid = false;
246     preorder_expr(expr, &mut |e| {
247         invalid |=
248             matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
249         invalid
250     });
251     if !invalid {
252         for_each_tail_expr(expr, &mut |e| {
253             if invalid {
254                 return;
255             }
256             let e = match e {
257                 ast::Expr::BreakExpr(e) => e.expr(),
258                 e @ ast::Expr::CallExpr(_) => Some(e.clone()),
259                 _ => None,
260             };
261             if let Some(ast::Expr::CallExpr(call)) = e {
262                 if let Some(ast::Expr::PathExpr(p)) = call.expr() {
263                     let res = p.path().and_then(|p| sema.resolve_path(&p));
264                     if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res {
265                         return invalid |= v != some_variant;
266                     }
267                 }
268             }
269             invalid = true
270         });
271     }
272     invalid
273 }
274
275 fn block_is_none_variant(
276     sema: &Semantics<RootDatabase>,
277     block: &ast::BlockExpr,
278     none_variant: hir::Variant,
279 ) -> bool {
280     block_as_lone_tail(block).and_then(|e| match e {
281         ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
282             hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
283             _ => None,
284         },
285         _ => None,
286     }) == Some(none_variant)
287 }
288
// Fixture-based tests for `convert_if_to_bool_then` and `convert_bool_then_to_if`.
// `$0` in a fixture marks the cursor position; `//- minicore:` selects the core items provided.
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // `if c { Some(x) } else { None }` becomes `c.then(|| x)`.
    #[test]
    fn convert_if_to_bool_then_simple() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        None
    }
}
",
            r"
fn main() {
    true.then(|| 15)
}
",
        );
    }

    // With `None` in the `then` branch, the condition is inverted (literal `true` -> `false`).
    #[test]
    fn convert_if_to_bool_then_invert() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        Some(15)
    }
}
",
            r"
fn main() {
    false.then(|| 15)
}
",
        );
    }

    // Both branches `None`: nothing to put in the closure.
    #[test]
    fn convert_if_to_bool_then_none_none() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        None
    }
}
",
        );
    }

    // Neither branch is `None`.
    #[test]
    fn convert_if_to_bool_then_some_some() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        Some(15)
    }
}
",
        );
    }

    // The non-`None` branch has a tail that is itself `None`, so not every tail is `Some(..)`.
    #[test]
    fn convert_if_to_bool_then_mixed() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            Some(15)
        } else {
            None
        }
    } else {
        None
    }
}
",
        );
    }

    // `else if` chains are rejected.
    #[test]
    fn convert_if_to_bool_then_chain() {
        cov_mark::check!(convert_if_to_bool_then_chain);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else if true {
        None
    } else {
        None
    }
}
",
        );
    }

    // `if let` pattern conditions are rejected.
    #[test]
    fn convert_if_to_bool_then_pattern_cond() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 let true = true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Bodies with a non-`Some` tail (first fixture) or a `return` (second fixture) are rejected.
    #[test]
    fn convert_if_to_bool_then_pattern_invalid_body() {
        cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn make_me_an_option() -> Option<i32> { None }
fn main() {
    if$0 true {
        if true {
            make_me_an_option()
        } else {
            Some(15)
        }
    } else {
        None
    }
}
",
        );
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            return;
        }
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Not offered for a non-bool receiver, a non-closure argument, or extra arguments.
    #[test]
    fn convert_bool_then_to_if_inapplicable() {
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    0.t$0hen(|| 15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15, 15);
}
",
        );
    }

    // Expression-bodied and block-bodied closures both produce the same `if`/`else`.
    #[test]
    fn convert_bool_then_to_if_simple() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15)
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        15
    })
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // `break` values of a tail `loop` are wrapped in `Some(..)` individually.
    #[test]
    fn convert_bool_then_to_if_tails() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        loop {
            if false {
                break 0;
            }
            break 15;
        }
    })
}
",
            r"
fn main() {
    if true {
        loop {
            if false {
                break Some(0);
            }
            break Some(15);
        }
    } else {
        None
    }
}
",
        );
    }
}