git.lizzy.rs Git - rust.git/blob - src/tools/clippy/clippy_lints/src/loops/manual_find.rs
Rollup merge of #106328 - GuillaumeGomez:gui-test-explanation, r=notriddle
[rust.git] / src / tools / clippy / clippy_lints / src / loops / manual_find.rs
1 use super::utils::make_iterator_snippet;
2 use super::MANUAL_FIND;
3 use clippy_utils::{
4     diagnostics::span_lint_and_then, higher, is_res_lang_ctor, path_res, peel_blocks_with_stmt,
5     source::snippet_with_applicability, ty::implements_trait,
6 };
7 use if_chain::if_chain;
8 use rustc_errors::Applicability;
9 use rustc_hir::{
10     def::Res, lang_items::LangItem, BindingAnnotation, Block, Expr, ExprKind, HirId, Node, Pat, PatKind, Stmt, StmtKind,
11 };
12 use rustc_lint::LateContext;
13 use rustc_span::source_map::Span;
14
15 pub(super) fn check<'tcx>(
16     cx: &LateContext<'tcx>,
17     pat: &'tcx Pat<'_>,
18     arg: &'tcx Expr<'_>,
19     body: &'tcx Expr<'_>,
20     span: Span,
21     expr: &'tcx Expr<'_>,
22 ) {
23     let inner_expr = peel_blocks_with_stmt(body);
24     // Check for the specific case that the result is returned and optimize suggestion for that (more
25     // cases can be added later)
26     if_chain! {
27         if let Some(higher::If { cond, then, r#else: None, }) = higher::If::hir(inner_expr);
28         if let Some(binding_id) = get_binding(pat);
29         if let ExprKind::Block(block, _) = then.kind;
30         if let [stmt] = block.stmts;
31         if let StmtKind::Semi(semi) = stmt.kind;
32         if let ExprKind::Ret(Some(ret_value)) = semi.kind;
33         if let ExprKind::Call(ctor, [inner_ret]) = ret_value.kind;
34         if is_res_lang_ctor(cx, path_res(cx, ctor), LangItem::OptionSome);
35         if path_res(cx, inner_ret) == Res::Local(binding_id);
36         if let Some((last_stmt, last_ret)) = last_stmt_and_ret(cx, expr);
37         then {
38             let mut applicability = Applicability::MachineApplicable;
39             let mut snippet = make_iterator_snippet(cx, arg, &mut applicability);
40             // Checks if `pat` is a single reference to a binding (`&x`)
41             let is_ref_to_binding =
42                 matches!(pat.kind, PatKind::Ref(inner, _) if matches!(inner.kind, PatKind::Binding(..)));
43             // If `pat` is not a binding or a reference to a binding (`x` or `&x`)
44             // we need to map it to the binding returned by the function (i.e. `.map(|(x, _)| x)`)
45             if !(matches!(pat.kind, PatKind::Binding(..)) || is_ref_to_binding) {
46                 snippet.push_str(
47                     &format!(
48                         ".map(|{}| {})",
49                         snippet_with_applicability(cx, pat.span, "..", &mut applicability),
50                         snippet_with_applicability(cx, inner_ret.span, "..", &mut applicability),
51                     )[..],
52                 );
53             }
54             let ty = cx.typeck_results().expr_ty(inner_ret);
55             if cx.tcx.lang_items().copy_trait().map_or(false, |id| implements_trait(cx, ty, id, &[])) {
56                 snippet.push_str(
57                     &format!(
58                         ".find(|{}{}| {})",
59                         "&".repeat(1 + usize::from(is_ref_to_binding)),
60                         snippet_with_applicability(cx, inner_ret.span, "..", &mut applicability),
61                         snippet_with_applicability(cx, cond.span, "..", &mut applicability),
62                     )[..],
63                 );
64                 if is_ref_to_binding {
65                     snippet.push_str(".copied()");
66                 }
67             } else {
68                 applicability = Applicability::MaybeIncorrect;
69                 snippet.push_str(
70                     &format!(
71                         ".find(|{}| {})",
72                         snippet_with_applicability(cx, inner_ret.span, "..", &mut applicability),
73                         snippet_with_applicability(cx, cond.span, "..", &mut applicability),
74                     )[..],
75                 );
76             }
77             // Extends to `last_stmt` to include semicolon in case of `return None;`
78             let lint_span = span.to(last_stmt.span).to(last_ret.span);
79             span_lint_and_then(
80                 cx,
81                 MANUAL_FIND,
82                 lint_span,
83                 "manual implementation of `Iterator::find`",
84                 |diag| {
85                     if applicability == Applicability::MaybeIncorrect {
86                         diag.note("you may need to dereference some variables");
87                     }
88                     diag.span_suggestion(
89                         lint_span,
90                         "replace with an iterator",
91                         snippet,
92                         applicability,
93                     );
94                 },
95             );
96         }
97     }
98 }
99
100 fn get_binding(pat: &Pat<'_>) -> Option<HirId> {
101     let mut hir_id = None;
102     let mut count = 0;
103     pat.each_binding(|annotation, id, _, _| {
104         count += 1;
105         if count > 1 {
106             hir_id = None;
107             return;
108         }
109         if let BindingAnnotation::NONE = annotation {
110             hir_id = Some(id);
111         }
112     });
113     hir_id
114 }
115
116 // Returns the last statement and last return if function fits format for lint
117 fn last_stmt_and_ret<'tcx>(
118     cx: &LateContext<'tcx>,
119     expr: &'tcx Expr<'_>,
120 ) -> Option<(&'tcx Stmt<'tcx>, &'tcx Expr<'tcx>)> {
121     // Returns last non-return statement and the last return
122     fn extract<'tcx>(block: &Block<'tcx>) -> Option<(&'tcx Stmt<'tcx>, &'tcx Expr<'tcx>)> {
123         if let [.., last_stmt] = block.stmts {
124             if let Some(ret) = block.expr {
125                 return Some((last_stmt, ret));
126             }
127             if_chain! {
128                 if let [.., snd_last, _] = block.stmts;
129                 if let StmtKind::Semi(last_expr) = last_stmt.kind;
130                 if let ExprKind::Ret(Some(ret)) = last_expr.kind;
131                 then {
132                     return Some((snd_last, ret));
133                 }
134             }
135         }
136         None
137     }
138     let mut parent_iter = cx.tcx.hir().parent_iter(expr.hir_id);
139     if_chain! {
140         // This should be the loop
141         if let Some((node_hir, Node::Stmt(..))) = parent_iter.next();
142         // This should be the function body
143         if let Some((_, Node::Block(block))) = parent_iter.next();
144         if let Some((last_stmt, last_ret)) = extract(block);
145         if last_stmt.hir_id == node_hir;
146         if is_res_lang_ctor(cx, path_res(cx, last_ret), LangItem::OptionNone);
147         if let Some((_, Node::Expr(_block))) = parent_iter.next();
148         // This includes the function header
149         if let Some((_, func)) = parent_iter.next();
150         if func.fn_kind().is_some();
151         then {
152             Some((block.stmts.last().unwrap(), last_ret))
153         } else {
154             None
155         }
156     }
157 }