]> git.lizzy.rs Git - rust.git/blob - clippy_lints/src/loops/needless_collect.rs
Auto merge of #92805 - BoxyUwU:revert-lazy-anon-const-substs, r=lcnr
[rust.git] / clippy_lints / src / loops / needless_collect.rs
1 use super::NEEDLESS_COLLECT;
2 use clippy_utils::diagnostics::{span_lint_and_sugg, span_lint_hir_and_then};
3 use clippy_utils::source::{snippet, snippet_with_applicability};
4 use clippy_utils::sugg::Sugg;
5 use clippy_utils::ty::is_type_diagnostic_item;
6 use clippy_utils::{can_move_expr_to_closure, is_trait_method, path_to_local, path_to_local_id, CaptureKind};
7 use if_chain::if_chain;
8 use rustc_data_structures::fx::FxHashMap;
9 use rustc_errors::Applicability;
10 use rustc_hir::intravisit::{walk_block, walk_expr, NestedVisitorMap, Visitor};
11 use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, Local, Mutability, Node, PatKind, Stmt, StmtKind};
12 use rustc_lint::LateContext;
13 use rustc_middle::hir::map::Map;
14 use rustc_middle::ty::subst::GenericArgKind;
15 use rustc_middle::ty::{self, TyS};
16 use rustc_span::sym;
17 use rustc_span::{MultiSpan, Span};
18
19 const NEEDLESS_COLLECT_MSG: &str = "avoid using `collect()` when not needed";
20
21 pub(super) fn check<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
22     check_needless_collect_direct_usage(expr, cx);
23     check_needless_collect_indirect_usage(expr, cx);
24 }
25 fn check_needless_collect_direct_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
26     if_chain! {
27         if let ExprKind::MethodCall(method, _, args, _) = expr.kind;
28         if let ExprKind::MethodCall(chain_method, method0_span, _, _) = args[0].kind;
29         if chain_method.ident.name == sym!(collect) && is_trait_method(cx, &args[0], sym::Iterator);
30         then {
31             let ty = cx.typeck_results().expr_ty(&args[0]);
32             let mut applicability = Applicability::MaybeIncorrect;
33             let is_empty_sugg = "next().is_none()".to_string();
34             let method_name = method.ident.name.as_str();
35             let sugg = if is_type_diagnostic_item(cx, ty, sym::Vec) ||
36                         is_type_diagnostic_item(cx, ty, sym::VecDeque) ||
37                         is_type_diagnostic_item(cx, ty, sym::LinkedList) ||
38                         is_type_diagnostic_item(cx, ty, sym::BinaryHeap) {
39                 match method_name {
40                     "len" => "count()".to_string(),
41                     "is_empty" => is_empty_sugg,
42                     "contains" => {
43                         let contains_arg = snippet_with_applicability(cx, args[1].span, "??", &mut applicability);
44                         let (arg, pred) = contains_arg
45                             .strip_prefix('&')
46                             .map_or(("&x", &*contains_arg), |s| ("x", s));
47                         format!("any(|{}| x == {})", arg, pred)
48                     }
49                     _ => return,
50                 }
51             }
52             else if is_type_diagnostic_item(cx, ty, sym::BTreeMap) ||
53                 is_type_diagnostic_item(cx, ty, sym::HashMap) {
54                 match method_name {
55                     "is_empty" => is_empty_sugg,
56                     _ => return,
57                 }
58             }
59             else {
60                 return;
61             };
62             span_lint_and_sugg(
63                 cx,
64                 NEEDLESS_COLLECT,
65                 method0_span.with_hi(expr.span.hi()),
66                 NEEDLESS_COLLECT_MSG,
67                 "replace with",
68                 sugg,
69                 applicability,
70             );
71         }
72     }
73 }
74
75 fn check_needless_collect_indirect_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
76     if let ExprKind::Block(block, _) = expr.kind {
77         for stmt in block.stmts {
78             if_chain! {
79                 if let StmtKind::Local(local) = stmt.kind;
80                 if let PatKind::Binding(_, id, ..) = local.pat.kind;
81                 if let Some(init_expr) = local.init;
82                 if let ExprKind::MethodCall(method_name, collect_span, &[ref iter_source], ..) = init_expr.kind;
83                 if method_name.ident.name == sym!(collect) && is_trait_method(cx, init_expr, sym::Iterator);
84                 let ty = cx.typeck_results().expr_ty(init_expr);
85                 if is_type_diagnostic_item(cx, ty, sym::Vec) ||
86                     is_type_diagnostic_item(cx, ty, sym::VecDeque) ||
87                     is_type_diagnostic_item(cx, ty, sym::BinaryHeap) ||
88                     is_type_diagnostic_item(cx, ty, sym::LinkedList);
89                 let iter_ty = cx.typeck_results().expr_ty(iter_source);
90                 if let Some(iter_calls) = detect_iter_and_into_iters(block, id, cx, get_captured_ids(cx, iter_ty));
91                 if let [iter_call] = &*iter_calls;
92                 then {
93                     let mut used_count_visitor = UsedCountVisitor {
94                         cx,
95                         id,
96                         count: 0,
97                     };
98                     walk_block(&mut used_count_visitor, block);
99                     if used_count_visitor.count > 1 {
100                         return;
101                     }
102
103                     // Suggest replacing iter_call with iter_replacement, and removing stmt
104                     let mut span = MultiSpan::from_span(collect_span);
105                     span.push_span_label(iter_call.span, "the iterator could be used here instead".into());
106                     span_lint_hir_and_then(
107                         cx,
108                         super::NEEDLESS_COLLECT,
109                         init_expr.hir_id,
110                         span,
111                         NEEDLESS_COLLECT_MSG,
112                         |diag| {
113                             let iter_replacement = format!("{}{}", Sugg::hir(cx, iter_source, ".."), iter_call.get_iter_method(cx));
114                             diag.multipart_suggestion(
115                                 iter_call.get_suggestion_text(),
116                                 vec![
117                                     (stmt.span, String::new()),
118                                     (iter_call.span, iter_replacement)
119                                 ],
120                                 Applicability::MaybeIncorrect,
121                             );
122                         },
123                     );
124                 }
125             }
126         }
127     }
128 }
129
130 struct IterFunction {
131     func: IterFunctionKind,
132     span: Span,
133 }
134 impl IterFunction {
135     fn get_iter_method(&self, cx: &LateContext<'_>) -> String {
136         match &self.func {
137             IterFunctionKind::IntoIter => String::new(),
138             IterFunctionKind::Len => String::from(".count()"),
139             IterFunctionKind::IsEmpty => String::from(".next().is_none()"),
140             IterFunctionKind::Contains(span) => {
141                 let s = snippet(cx, *span, "..");
142                 if let Some(stripped) = s.strip_prefix('&') {
143                     format!(".any(|x| x == {})", stripped)
144                 } else {
145                     format!(".any(|x| x == *{})", s)
146                 }
147             },
148         }
149     }
150     fn get_suggestion_text(&self) -> &'static str {
151         match &self.func {
152             IterFunctionKind::IntoIter => {
153                 "use the original Iterator instead of collecting it and then producing a new one"
154             },
155             IterFunctionKind::Len => {
156                 "take the original Iterator's count instead of collecting it and finding the length"
157             },
158             IterFunctionKind::IsEmpty => {
159                 "check if the original Iterator has anything instead of collecting it and seeing if it's empty"
160             },
161             IterFunctionKind::Contains(_) => {
162                 "check if the original Iterator contains an element instead of collecting then checking"
163             },
164         }
165     }
166 }
167 enum IterFunctionKind {
168     IntoIter,
169     Len,
170     IsEmpty,
171     Contains(Span),
172 }
173
174 struct IterFunctionVisitor<'a, 'tcx> {
175     illegal_mutable_capture_ids: HirIdSet,
176     current_mutably_captured_ids: HirIdSet,
177     cx: &'a LateContext<'tcx>,
178     uses: Vec<Option<IterFunction>>,
179     hir_id_uses_map: FxHashMap<HirId, usize>,
180     current_statement_hir_id: Option<HirId>,
181     seen_other: bool,
182     target: HirId,
183 }
184 impl<'tcx> Visitor<'tcx> for IterFunctionVisitor<'_, 'tcx> {
185     fn visit_block(&mut self, block: &'tcx Block<'tcx>) {
186         for (expr, hir_id) in block.stmts.iter().filter_map(get_expr_and_hir_id_from_stmt) {
187             self.visit_block_expr(expr, hir_id);
188         }
189         if let Some(expr) = block.expr {
190             self.visit_block_expr(expr, None);
191         }
192     }
193
194     fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
195         // Check function calls on our collection
196         if let ExprKind::MethodCall(method_name, _, [recv, args @ ..], _) = &expr.kind {
197             if method_name.ident.name == sym!(collect) && is_trait_method(self.cx, expr, sym::Iterator) {
198                 self.current_mutably_captured_ids = get_captured_ids(self.cx, self.cx.typeck_results().expr_ty(recv));
199                 self.visit_expr(recv);
200                 return;
201             }
202
203             if path_to_local_id(recv, self.target) {
204                 if self
205                     .illegal_mutable_capture_ids
206                     .intersection(&self.current_mutably_captured_ids)
207                     .next()
208                     .is_none()
209                 {
210                     if let Some(hir_id) = self.current_statement_hir_id {
211                         self.hir_id_uses_map.insert(hir_id, self.uses.len());
212                     }
213                     match method_name.ident.name.as_str() {
214                         "into_iter" => self.uses.push(Some(IterFunction {
215                             func: IterFunctionKind::IntoIter,
216                             span: expr.span,
217                         })),
218                         "len" => self.uses.push(Some(IterFunction {
219                             func: IterFunctionKind::Len,
220                             span: expr.span,
221                         })),
222                         "is_empty" => self.uses.push(Some(IterFunction {
223                             func: IterFunctionKind::IsEmpty,
224                             span: expr.span,
225                         })),
226                         "contains" => self.uses.push(Some(IterFunction {
227                             func: IterFunctionKind::Contains(args[0].span),
228                             span: expr.span,
229                         })),
230                         _ => {
231                             self.seen_other = true;
232                             if let Some(hir_id) = self.current_statement_hir_id {
233                                 self.hir_id_uses_map.remove(&hir_id);
234                             }
235                         },
236                     }
237                 }
238                 return;
239             }
240
241             if let Some(hir_id) = path_to_local(recv) {
242                 if let Some(index) = self.hir_id_uses_map.remove(&hir_id) {
243                     if self
244                         .illegal_mutable_capture_ids
245                         .intersection(&self.current_mutably_captured_ids)
246                         .next()
247                         .is_none()
248                     {
249                         if let Some(hir_id) = self.current_statement_hir_id {
250                             self.hir_id_uses_map.insert(hir_id, index);
251                         }
252                     } else {
253                         self.uses[index] = None;
254                     }
255                 }
256             }
257         }
258         // Check if the collection is used for anything else
259         if path_to_local_id(expr, self.target) {
260             self.seen_other = true;
261         } else {
262             walk_expr(self, expr);
263         }
264     }
265
266     type Map = Map<'tcx>;
267     fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
268         NestedVisitorMap::None
269     }
270 }
271
272 impl<'tcx> IterFunctionVisitor<'_, 'tcx> {
273     fn visit_block_expr(&mut self, expr: &'tcx Expr<'tcx>, hir_id: Option<HirId>) {
274         self.current_statement_hir_id = hir_id;
275         self.current_mutably_captured_ids = get_captured_ids(self.cx, self.cx.typeck_results().expr_ty(expr));
276         self.visit_expr(expr);
277     }
278 }
279
280 fn get_expr_and_hir_id_from_stmt<'v>(stmt: &'v Stmt<'v>) -> Option<(&'v Expr<'v>, Option<HirId>)> {
281     match stmt.kind {
282         StmtKind::Expr(expr) | StmtKind::Semi(expr) => Some((expr, None)),
283         StmtKind::Item(..) => None,
284         StmtKind::Local(Local { init, pat, .. }) => {
285             if let PatKind::Binding(_, hir_id, ..) = pat.kind {
286                 init.map(|init_expr| (init_expr, Some(hir_id)))
287             } else {
288                 init.map(|init_expr| (init_expr, None))
289             }
290         },
291     }
292 }
293
294 struct UsedCountVisitor<'a, 'tcx> {
295     cx: &'a LateContext<'tcx>,
296     id: HirId,
297     count: usize,
298 }
299
300 impl<'a, 'tcx> Visitor<'tcx> for UsedCountVisitor<'a, 'tcx> {
301     type Map = Map<'tcx>;
302
303     fn visit_expr(&mut self, expr: &'tcx Expr<'_>) {
304         if path_to_local_id(expr, self.id) {
305             self.count += 1;
306         } else {
307             walk_expr(self, expr);
308         }
309     }
310
311     fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
312         NestedVisitorMap::OnlyBodies(self.cx.tcx.hir())
313     }
314 }
315
316 /// Detect the occurrences of calls to `iter` or `into_iter` for the
317 /// given identifier
318 fn detect_iter_and_into_iters<'tcx: 'a, 'a>(
319     block: &'tcx Block<'tcx>,
320     id: HirId,
321     cx: &'a LateContext<'tcx>,
322     captured_ids: HirIdSet,
323 ) -> Option<Vec<IterFunction>> {
324     let mut visitor = IterFunctionVisitor {
325         uses: Vec::new(),
326         target: id,
327         seen_other: false,
328         cx,
329         current_mutably_captured_ids: HirIdSet::default(),
330         illegal_mutable_capture_ids: captured_ids,
331         hir_id_uses_map: FxHashMap::default(),
332         current_statement_hir_id: None,
333     };
334     visitor.visit_block(block);
335     if visitor.seen_other {
336         None
337     } else {
338         Some(visitor.uses.into_iter().flatten().collect())
339     }
340 }
341
342 fn get_captured_ids(cx: &LateContext<'_>, ty: &'_ TyS<'_>) -> HirIdSet {
343     fn get_captured_ids_recursive(cx: &LateContext<'_>, ty: &'_ TyS<'_>, set: &mut HirIdSet) {
344         match ty.kind() {
345             ty::Adt(_, generics) => {
346                 for generic in *generics {
347                     if let GenericArgKind::Type(ty) = generic.unpack() {
348                         get_captured_ids_recursive(cx, ty, set);
349                     }
350                 }
351             },
352             ty::Closure(def_id, _) => {
353                 let closure_hir_node = cx.tcx.hir().get_if_local(*def_id).unwrap();
354                 if let Node::Expr(closure_expr) = closure_hir_node {
355                     can_move_expr_to_closure(cx, closure_expr)
356                         .unwrap()
357                         .into_iter()
358                         .for_each(|(hir_id, capture_kind)| {
359                             if matches!(capture_kind, CaptureKind::Ref(Mutability::Mut)) {
360                                 set.insert(hir_id);
361                             }
362                         });
363                 }
364             },
365             _ => (),
366         }
367     }
368
369     let mut set = HirIdSet::default();
370
371     get_captured_ids_recursive(cx, ty, &mut set);
372
373     set
374 }