]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/convert_bool_then.rs
Rename `*Owner` traits to `Has*`
[rust.git] / crates / ide_assists / src / handlers / convert_bool_then.rs
1 use hir::{known, AsAssocItem, Semantics};
2 use ide_db::{
3     helpers::{
4         for_each_tail_expr,
5         node_ext::{block_as_lone_tail, preorder_expr},
6         FamousDefs,
7     },
8     RootDatabase,
9 };
10 use itertools::Itertools;
11 use syntax::{
12     ast::{self, edit::AstNodeEdit, make, HasArgList},
13     ted, AstNode, SyntaxNode,
14 };
15
16 use crate::{
17     utils::{invert_boolean_expression, unwrap_trivial_block},
18     AssistContext, AssistId, AssistKind, Assists,
19 };
20
21 // Assist: convert_if_to_bool_then
22 //
23 // Converts an if expression into a corresponding `bool::then` call.
24 //
25 // ```
26 // # //- minicore: option
27 // fn main() {
28 //     if$0 cond {
29 //         Some(val)
30 //     } else {
31 //         None
32 //     }
33 // }
34 // ```
35 // ->
36 // ```
37 // fn main() {
38 //     cond.then(|| val)
39 // }
40 // ```
41 pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
42     // todo, applies to match as well
43     let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
44     if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
45         return None;
46     }
47
48     let cond = expr.condition().filter(|cond| !cond.is_pattern_cond())?;
49     let cond = cond.expr()?;
50     let then = expr.then_branch()?;
51     let else_ = match expr.else_branch()? {
52         ast::ElseBranch::Block(b) => b,
53         ast::ElseBranch::IfExpr(_) => {
54             cov_mark::hit!(convert_if_to_bool_then_chain);
55             return None;
56         }
57     };
58
59     let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
60
61     let (invert_cond, closure_body) = match (
62         block_is_none_variant(&ctx.sema, &then, none_variant),
63         block_is_none_variant(&ctx.sema, &else_, none_variant),
64     ) {
65         (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
66         (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
67         _ => return None,
68     };
69
70     if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
71         cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
72         return None;
73     }
74
75     let target = expr.syntax().text_range();
76     acc.add(
77         AssistId("convert_if_to_bool_then", AssistKind::RefactorRewrite),
78         "Convert `if` expression to `bool::then` call",
79         target,
80         |builder| {
81             let closure_body = closure_body.clone_for_update();
82             // Rewrite all `Some(e)` in tail position to `e`
83             let mut replacements = Vec::new();
84             for_each_tail_expr(&closure_body, &mut |e| {
85                 let e = match e {
86                     ast::Expr::BreakExpr(e) => e.expr(),
87                     e @ ast::Expr::CallExpr(_) => Some(e.clone()),
88                     _ => None,
89                 };
90                 if let Some(ast::Expr::CallExpr(call)) = e {
91                     if let Some(arg_list) = call.arg_list() {
92                         if let Some(arg) = arg_list.args().next() {
93                             replacements.push((call.syntax().clone(), arg.syntax().clone()));
94                         }
95                     }
96                 }
97             });
98             replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
99             let closure_body = match closure_body {
100                 ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
101                 e => e,
102             };
103
104             let cond = if invert_cond { invert_boolean_expression(cond) } else { cond };
105             let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
106             let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
107             builder.replace(target, mcall.to_string());
108         },
109     )
110 }
111
112 // Assist: convert_bool_then_to_if
113 //
114 // Converts a `bool::then` method call to an equivalent if expression.
115 //
116 // ```
117 // # //- minicore: bool_impl
118 // fn main() {
119 //     (0 == 0).then$0(|| val)
120 // }
121 // ```
122 // ->
123 // ```
124 // fn main() {
125 //     if 0 == 0 {
126 //         Some(val)
127 //     } else {
128 //         None
129 //     }
130 // }
131 // ```
132 pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
133     let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
134     let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
135     let receiver = mcall.receiver()?;
136     let closure_body = mcall.arg_list()?.args().exactly_one().ok()?;
137     let closure_body = match closure_body {
138         ast::Expr::ClosureExpr(expr) => expr.body()?,
139         _ => return None,
140     };
141     // Verify this is `bool::then` that is being called.
142     let func = ctx.sema.resolve_method_call(&mcall)?;
143     if func.name(ctx.sema.db).to_string() != "then" {
144         return None;
145     }
146     let assoc = func.as_assoc_item(ctx.sema.db)?;
147     match assoc.container(ctx.sema.db) {
148         hir::AssocItemContainer::Impl(impl_) if impl_.self_ty(ctx.sema.db).is_bool() => {}
149         _ => return None,
150     }
151
152     let target = mcall.syntax().text_range();
153     acc.add(
154         AssistId("convert_bool_then_to_if", AssistKind::RefactorRewrite),
155         "Convert `bool::then` call to `if`",
156         target,
157         |builder| {
158             let closure_body = match closure_body {
159                 ast::Expr::BlockExpr(block) => block,
160                 e => make::block_expr(None, Some(e)),
161             };
162
163             let closure_body = closure_body.clone_for_update();
164             // Wrap all tails in `Some(...)`
165             let none_path = make::expr_path(make::ext::ident_path("None"));
166             let some_path = make::expr_path(make::ext::ident_path("Some"));
167             let mut replacements = Vec::new();
168             for_each_tail_expr(&ast::Expr::BlockExpr(closure_body.clone()), &mut |e| {
169                 let e = match e {
170                     ast::Expr::BreakExpr(e) => e.expr(),
171                     ast::Expr::ReturnExpr(e) => e.expr(),
172                     _ => Some(e.clone()),
173                 };
174                 if let Some(expr) = e {
175                     replacements.push((
176                         expr.syntax().clone(),
177                         make::expr_call(some_path.clone(), make::arg_list(Some(expr)))
178                             .syntax()
179                             .clone_for_update(),
180                     ));
181                 }
182             });
183             replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
184
185             let cond = match &receiver {
186                 ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
187                 _ => receiver,
188             };
189             let if_expr = make::expr_if(
190                 make::condition(cond, None),
191                 closure_body.reset_indent(),
192                 Some(ast::ElseBranch::Block(make::block_expr(None, Some(none_path)))),
193             )
194             .indent(mcall.indent_level());
195
196             builder.replace(target, if_expr.to_string());
197         },
198     )
199 }
200
201 fn option_variants(
202     sema: &Semantics<RootDatabase>,
203     expr: &SyntaxNode,
204 ) -> Option<(hir::Variant, hir::Variant)> {
205     let fam = FamousDefs(sema, sema.scope(expr).krate());
206     let option_variants = fam.core_option_Option()?.variants(sema.db);
207     match &*option_variants {
208         &[variant0, variant1] => Some(if variant0.name(sema.db) == known::None {
209             (variant0, variant1)
210         } else {
211             (variant1, variant0)
212         }),
213         _ => None,
214     }
215 }
216
217 /// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression.
218 /// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call.
219 fn is_invalid_body(
220     sema: &Semantics<RootDatabase>,
221     some_variant: hir::Variant,
222     expr: &ast::Expr,
223 ) -> bool {
224     let mut invalid = false;
225     preorder_expr(expr, &mut |e| {
226         invalid |=
227             matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
228         invalid
229     });
230     if !invalid {
231         for_each_tail_expr(expr, &mut |e| {
232             if invalid {
233                 return;
234             }
235             let e = match e {
236                 ast::Expr::BreakExpr(e) => e.expr(),
237                 e @ ast::Expr::CallExpr(_) => Some(e.clone()),
238                 _ => None,
239             };
240             if let Some(ast::Expr::CallExpr(call)) = e {
241                 if let Some(ast::Expr::PathExpr(p)) = call.expr() {
242                     let res = p.path().and_then(|p| sema.resolve_path(&p));
243                     if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res {
244                         return invalid |= v != some_variant;
245                     }
246                 }
247             }
248             invalid = true
249         });
250     }
251     invalid
252 }
253
254 fn block_is_none_variant(
255     sema: &Semantics<RootDatabase>,
256     block: &ast::BlockExpr,
257     none_variant: hir::Variant,
258 ) -> bool {
259     block_as_lone_tail(block).and_then(|e| match e {
260         ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
261             hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
262             _ => None,
263         },
264         _ => None,
265     }) == Some(none_variant)
266 }
267
268 #[cfg(test)]
269 mod tests {
270     use crate::tests::{check_assist, check_assist_not_applicable};
271
272     use super::*;
273
274     #[test]
275     fn convert_if_to_bool_then_simple() {
276         check_assist(
277             convert_if_to_bool_then,
278             r"
279 //- minicore:option
280 fn main() {
281     if$0 true {
282         Some(15)
283     } else {
284         None
285     }
286 }
287 ",
288             r"
289 fn main() {
290     true.then(|| 15)
291 }
292 ",
293         );
294     }
295
296     #[test]
297     fn convert_if_to_bool_then_invert() {
298         check_assist(
299             convert_if_to_bool_then,
300             r"
301 //- minicore:option
302 fn main() {
303     if$0 true {
304         None
305     } else {
306         Some(15)
307     }
308 }
309 ",
310             r"
311 fn main() {
312     false.then(|| 15)
313 }
314 ",
315         );
316     }
317
318     #[test]
319     fn convert_if_to_bool_then_none_none() {
320         check_assist_not_applicable(
321             convert_if_to_bool_then,
322             r"
323 //- minicore:option
324 fn main() {
325     if$0 true {
326         None
327     } else {
328         None
329     }
330 }
331 ",
332         );
333     }
334
335     #[test]
336     fn convert_if_to_bool_then_some_some() {
337         check_assist_not_applicable(
338             convert_if_to_bool_then,
339             r"
340 //- minicore:option
341 fn main() {
342     if$0 true {
343         Some(15)
344     } else {
345         Some(15)
346     }
347 }
348 ",
349         );
350     }
351
352     #[test]
353     fn convert_if_to_bool_then_mixed() {
354         check_assist_not_applicable(
355             convert_if_to_bool_then,
356             r"
357 //- minicore:option
358 fn main() {
359     if$0 true {
360         if true {
361             Some(15)
362         } else {
363             None
364         }
365     } else {
366         None
367     }
368 }
369 ",
370         );
371     }
372
373     #[test]
374     fn convert_if_to_bool_then_chain() {
375         cov_mark::check!(convert_if_to_bool_then_chain);
376         check_assist_not_applicable(
377             convert_if_to_bool_then,
378             r"
379 //- minicore:option
380 fn main() {
381     if$0 true {
382         Some(15)
383     } else if true {
384         None
385     } else {
386         None
387     }
388 }
389 ",
390         );
391     }
392
393     #[test]
394     fn convert_if_to_bool_then_pattern_cond() {
395         check_assist_not_applicable(
396             convert_if_to_bool_then,
397             r"
398 //- minicore:option
399 fn main() {
400     if$0 let true = true {
401         Some(15)
402     } else {
403         None
404     }
405 }
406 ",
407         );
408     }
409
410     #[test]
411     fn convert_if_to_bool_then_pattern_invalid_body() {
412         cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
413         check_assist_not_applicable(
414             convert_if_to_bool_then,
415             r"
416 //- minicore:option
417 fn make_me_an_option() -> Option<i32> { None }
418 fn main() {
419     if$0 true {
420         if true {
421             make_me_an_option()
422         } else {
423             Some(15)
424         }
425     } else {
426         None
427     }
428 }
429 ",
430         );
431         check_assist_not_applicable(
432             convert_if_to_bool_then,
433             r"
434 //- minicore:option
435 fn main() {
436     if$0 true {
437         if true {
438             return;
439         }
440         Some(15)
441     } else {
442         None
443     }
444 }
445 ",
446         );
447     }
448
449     #[test]
450     fn convert_bool_then_to_if_inapplicable() {
451         check_assist_not_applicable(
452             convert_bool_then_to_if,
453             r"
454 //- minicore:bool_impl
455 fn main() {
456     0.t$0hen(|| 15);
457 }
458 ",
459         );
460         check_assist_not_applicable(
461             convert_bool_then_to_if,
462             r"
463 //- minicore:bool_impl
464 fn main() {
465     true.t$0hen(15);
466 }
467 ",
468         );
469         check_assist_not_applicable(
470             convert_bool_then_to_if,
471             r"
472 //- minicore:bool_impl
473 fn main() {
474     true.t$0hen(|| 15, 15);
475 }
476 ",
477         );
478     }
479
480     #[test]
481     fn convert_bool_then_to_if_simple() {
482         check_assist(
483             convert_bool_then_to_if,
484             r"
485 //- minicore:bool_impl
486 fn main() {
487     true.t$0hen(|| 15)
488 }
489 ",
490             r"
491 fn main() {
492     if true {
493         Some(15)
494     } else {
495         None
496     }
497 }
498 ",
499         );
500         check_assist(
501             convert_bool_then_to_if,
502             r"
503 //- minicore:bool_impl
504 fn main() {
505     true.t$0hen(|| {
506         15
507     })
508 }
509 ",
510             r"
511 fn main() {
512     if true {
513         Some(15)
514     } else {
515         None
516     }
517 }
518 ",
519         );
520     }
521
522     #[test]
523     fn convert_bool_then_to_if_tails() {
524         check_assist(
525             convert_bool_then_to_if,
526             r"
527 //- minicore:bool_impl
528 fn main() {
529     true.t$0hen(|| {
530         loop {
531             if false {
532                 break 0;
533             }
534             break 15;
535         }
536     })
537 }
538 ",
539             r"
540 fn main() {
541     if true {
542         loop {
543             if false {
544                 break Some(0);
545             }
546             break Some(15);
547         }
548     } else {
549         None
550     }
551 }
552 ",
553         );
554     }
555 }