]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs
Rollup merge of #99544 - dylni:expose-utf8lossy, r=Mark-Simulacrum
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / convert_bool_then.rs
1 use hir::{known, AsAssocItem, Semantics};
2 use ide_db::{
3     famous_defs::FamousDefs,
4     syntax_helpers::node_ext::{
5         block_as_lone_tail, for_each_tail_expr, is_pattern_cond, preorder_expr,
6     },
7     RootDatabase,
8 };
9 use itertools::Itertools;
10 use syntax::{
11     ast::{self, edit::AstNodeEdit, make, HasArgList},
12     ted, AstNode, SyntaxNode,
13 };
14
15 use crate::{
16     utils::{invert_boolean_expression, unwrap_trivial_block},
17     AssistContext, AssistId, AssistKind, Assists,
18 };
19
20 // Assist: convert_if_to_bool_then
21 //
22 // Converts an if expression into a corresponding `bool::then` call.
23 //
24 // ```
25 // # //- minicore: option
26 // fn main() {
27 //     if$0 cond {
28 //         Some(val)
29 //     } else {
30 //         None
31 //     }
32 // }
33 // ```
34 // ->
35 // ```
36 // fn main() {
37 //     cond.then(|| val)
38 // }
39 // ```
40 pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
41     // FIXME applies to match as well
42     let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
43     if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
44         return None;
45     }
46
47     let cond = expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?;
48     let then = expr.then_branch()?;
49     let else_ = match expr.else_branch()? {
50         ast::ElseBranch::Block(b) => b,
51         ast::ElseBranch::IfExpr(_) => {
52             cov_mark::hit!(convert_if_to_bool_then_chain);
53             return None;
54         }
55     };
56
57     let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
58
59     let (invert_cond, closure_body) = match (
60         block_is_none_variant(&ctx.sema, &then, none_variant),
61         block_is_none_variant(&ctx.sema, &else_, none_variant),
62     ) {
63         (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
64         (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
65         _ => return None,
66     };
67
68     if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
69         cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
70         return None;
71     }
72
73     let target = expr.syntax().text_range();
74     acc.add(
75         AssistId("convert_if_to_bool_then", AssistKind::RefactorRewrite),
76         "Convert `if` expression to `bool::then` call",
77         target,
78         |builder| {
79             let closure_body = closure_body.clone_for_update();
80             // Rewrite all `Some(e)` in tail position to `e`
81             let mut replacements = Vec::new();
82             for_each_tail_expr(&closure_body, &mut |e| {
83                 let e = match e {
84                     ast::Expr::BreakExpr(e) => e.expr(),
85                     e @ ast::Expr::CallExpr(_) => Some(e.clone()),
86                     _ => None,
87                 };
88                 if let Some(ast::Expr::CallExpr(call)) = e {
89                     if let Some(arg_list) = call.arg_list() {
90                         if let Some(arg) = arg_list.args().next() {
91                             replacements.push((call.syntax().clone(), arg.syntax().clone()));
92                         }
93                     }
94                 }
95             });
96             replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
97             let closure_body = match closure_body {
98                 ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
99                 e => e,
100             };
101
102             let parenthesize = matches!(
103                 cond,
104                 ast::Expr::BinExpr(_)
105                     | ast::Expr::BlockExpr(_)
106                     | ast::Expr::BoxExpr(_)
107                     | ast::Expr::BreakExpr(_)
108                     | ast::Expr::CastExpr(_)
109                     | ast::Expr::ClosureExpr(_)
110                     | ast::Expr::ContinueExpr(_)
111                     | ast::Expr::ForExpr(_)
112                     | ast::Expr::IfExpr(_)
113                     | ast::Expr::LoopExpr(_)
114                     | ast::Expr::MacroExpr(_)
115                     | ast::Expr::MatchExpr(_)
116                     | ast::Expr::PrefixExpr(_)
117                     | ast::Expr::RangeExpr(_)
118                     | ast::Expr::RefExpr(_)
119                     | ast::Expr::ReturnExpr(_)
120                     | ast::Expr::WhileExpr(_)
121                     | ast::Expr::YieldExpr(_)
122             );
123             let cond = if invert_cond { invert_boolean_expression(cond) } else { cond };
124             let cond = if parenthesize { make::expr_paren(cond) } else { cond };
125             let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
126             let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
127             builder.replace(target, mcall.to_string());
128         },
129     )
130 }
131
132 // Assist: convert_bool_then_to_if
133 //
134 // Converts a `bool::then` method call to an equivalent if expression.
135 //
136 // ```
137 // # //- minicore: bool_impl
138 // fn main() {
139 //     (0 == 0).then$0(|| val)
140 // }
141 // ```
142 // ->
143 // ```
144 // fn main() {
145 //     if 0 == 0 {
146 //         Some(val)
147 //     } else {
148 //         None
149 //     }
150 // }
151 // ```
152 pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
153     let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
154     let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
155     let receiver = mcall.receiver()?;
156     let closure_body = mcall.arg_list()?.args().exactly_one().ok()?;
157     let closure_body = match closure_body {
158         ast::Expr::ClosureExpr(expr) => expr.body()?,
159         _ => return None,
160     };
161     // Verify this is `bool::then` that is being called.
162     let func = ctx.sema.resolve_method_call(&mcall)?;
163     if func.name(ctx.sema.db).to_string() != "then" {
164         return None;
165     }
166     let assoc = func.as_assoc_item(ctx.sema.db)?;
167     match assoc.container(ctx.sema.db) {
168         hir::AssocItemContainer::Impl(impl_) if impl_.self_ty(ctx.sema.db).is_bool() => {}
169         _ => return None,
170     }
171
172     let target = mcall.syntax().text_range();
173     acc.add(
174         AssistId("convert_bool_then_to_if", AssistKind::RefactorRewrite),
175         "Convert `bool::then` call to `if`",
176         target,
177         |builder| {
178             let closure_body = match closure_body {
179                 ast::Expr::BlockExpr(block) => block,
180                 e => make::block_expr(None, Some(e)),
181             };
182
183             let closure_body = closure_body.clone_for_update();
184             // Wrap all tails in `Some(...)`
185             let none_path = make::expr_path(make::ext::ident_path("None"));
186             let some_path = make::expr_path(make::ext::ident_path("Some"));
187             let mut replacements = Vec::new();
188             for_each_tail_expr(&ast::Expr::BlockExpr(closure_body.clone()), &mut |e| {
189                 let e = match e {
190                     ast::Expr::BreakExpr(e) => e.expr(),
191                     ast::Expr::ReturnExpr(e) => e.expr(),
192                     _ => Some(e.clone()),
193                 };
194                 if let Some(expr) = e {
195                     replacements.push((
196                         expr.syntax().clone(),
197                         make::expr_call(some_path.clone(), make::arg_list(Some(expr)))
198                             .syntax()
199                             .clone_for_update(),
200                     ));
201                 }
202             });
203             replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
204
205             let cond = match &receiver {
206                 ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
207                 _ => receiver,
208             };
209             let if_expr = make::expr_if(
210                 cond,
211                 closure_body.reset_indent(),
212                 Some(ast::ElseBranch::Block(make::block_expr(None, Some(none_path)))),
213             )
214             .indent(mcall.indent_level());
215
216             builder.replace(target, if_expr.to_string());
217         },
218     )
219 }
220
221 fn option_variants(
222     sema: &Semantics<'_, RootDatabase>,
223     expr: &SyntaxNode,
224 ) -> Option<(hir::Variant, hir::Variant)> {
225     let fam = FamousDefs(sema, sema.scope(expr)?.krate());
226     let option_variants = fam.core_option_Option()?.variants(sema.db);
227     match &*option_variants {
228         &[variant0, variant1] => Some(if variant0.name(sema.db) == known::None {
229             (variant0, variant1)
230         } else {
231             (variant1, variant0)
232         }),
233         _ => None,
234     }
235 }
236
237 /// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression.
238 /// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call.
239 fn is_invalid_body(
240     sema: &Semantics<'_, RootDatabase>,
241     some_variant: hir::Variant,
242     expr: &ast::Expr,
243 ) -> bool {
244     let mut invalid = false;
245     preorder_expr(expr, &mut |e| {
246         invalid |=
247             matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
248         invalid
249     });
250     if !invalid {
251         for_each_tail_expr(expr, &mut |e| {
252             if invalid {
253                 return;
254             }
255             let e = match e {
256                 ast::Expr::BreakExpr(e) => e.expr(),
257                 e @ ast::Expr::CallExpr(_) => Some(e.clone()),
258                 _ => None,
259             };
260             if let Some(ast::Expr::CallExpr(call)) = e {
261                 if let Some(ast::Expr::PathExpr(p)) = call.expr() {
262                     let res = p.path().and_then(|p| sema.resolve_path(&p));
263                     if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res {
264                         return invalid |= v != some_variant;
265                     }
266                 }
267             }
268             invalid = true
269         });
270     }
271     invalid
272 }
273
274 fn block_is_none_variant(
275     sema: &Semantics<'_, RootDatabase>,
276     block: &ast::BlockExpr,
277     none_variant: hir::Variant,
278 ) -> bool {
279     block_as_lone_tail(block).and_then(|e| match e {
280         ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
281             hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
282             _ => None,
283         },
284         _ => None,
285     }) == Some(none_variant)
286 }
287
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // Basic case: `Some` in the `then` branch, `None` in the `else` branch.
    #[test]
    fn convert_if_to_bool_then_simple() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        None
    }
}
",
            r"
fn main() {
    true.then(|| 15)
}
",
        );
    }

    // `None` in the `then` branch: the condition is negated (literal `true` becomes `false`).
    #[test]
    fn convert_if_to_bool_then_invert() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        Some(15)
    }
}
",
            r"
fn main() {
    false.then(|| 15)
}
",
        );
    }

    // Both branches are `None`: there is no value-producing branch, so not applicable.
    #[test]
    fn convert_if_to_bool_then_none_none() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        None
    }
}
",
        );
    }

    // Neither branch is `None`: not applicable.
    #[test]
    fn convert_if_to_bool_then_some_some() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        Some(15)
    }
}
",
        );
    }

    // The `then` branch has a `None` tail of its own, so not every tail is `Some(..)`.
    #[test]
    fn convert_if_to_bool_then_mixed() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            Some(15)
        } else {
            None
        }
    } else {
        None
    }
}
",
        );
    }

    // `else if` chains are rejected (exercises the `convert_if_to_bool_then_chain` mark).
    #[test]
    fn convert_if_to_bool_then_chain() {
        cov_mark::check!(convert_if_to_bool_then_chain);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else if true {
        None
    } else {
        None
    }
}
",
        );
    }

    // `if let` conditions are patterns, not `bool`s, so they are rejected.
    #[test]
    fn convert_if_to_bool_then_pattern_cond() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 let true = true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Two rejected bodies, each hitting the invalid-body mark once: a tail that is a plain
    // function call rather than `Some(..)`, and a body that contains `return`.
    #[test]
    fn convert_if_to_bool_then_pattern_invalid_body() {
        cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn make_me_an_option() -> Option<i32> { None }
fn main() {
    if$0 true {
        if true {
            make_me_an_option()
        } else {
            Some(15)
        }
    } else {
        None
    }
}
",
        );
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            return;
        }
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Rejected calls: a non-`bool` receiver, a non-closure argument, and two arguments.
    #[test]
    fn convert_bool_then_to_if_inapplicable() {
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    0.t$0hen(|| 15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15, 15);
}
",
        );
    }

    // Expression-bodied and block-bodied closures produce the same `if`/`else`.
    #[test]
    fn convert_bool_then_to_if_simple() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15)
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        15
    })
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // `break` values that form the closure's tails are each wrapped in `Some(..)`.
    #[test]
    fn convert_bool_then_to_if_tails() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        loop {
            if false {
                break 0;
            }
            break 15;
        }
    })
}
",
            r"
fn main() {
    if true {
        loop {
            if false {
                break Some(0);
            }
            break Some(15);
        }
    } else {
        None
    }
}
",
        );
    }
}