]> git.lizzy.rs Git - rust.git/blob - clippy_lints/src/loops/needless_collect.rs
92560c806295ccc01e9397f6f244637a920d60c3
[rust.git] / clippy_lints / src / loops / needless_collect.rs
1 use super::NEEDLESS_COLLECT;
2 use crate::utils::sugg::Sugg;
3 use crate::utils::{
4     is_type_diagnostic_item, match_trait_method, match_type, path_to_local_id, paths, snippet, span_lint_and_sugg,
5     span_lint_and_then,
6 };
7 use if_chain::if_chain;
8 use rustc_errors::Applicability;
9 use rustc_hir::intravisit::{walk_block, walk_expr, NestedVisitorMap, Visitor};
10 use rustc_hir::{Block, Expr, ExprKind, GenericArg, HirId, Local, Pat, PatKind, QPath, StmtKind};
11 use rustc_lint::LateContext;
12 use rustc_middle::hir::map::Map;
13 use rustc_span::source_map::Span;
14 use rustc_span::symbol::{sym, Ident};
15
16 const NEEDLESS_COLLECT_MSG: &str = "avoid using `collect()` when not needed";
17
18 pub(super) fn check<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
19     check_needless_collect_direct_usage(expr, cx);
20     check_needless_collect_indirect_usage(expr, cx);
21 }
22 fn check_needless_collect_direct_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
23     if_chain! {
24         if let ExprKind::MethodCall(ref method, _, ref args, _) = expr.kind;
25         if let ExprKind::MethodCall(ref chain_method, _, _, _) = args[0].kind;
26         if chain_method.ident.name == sym!(collect) && match_trait_method(cx, &args[0], &paths::ITERATOR);
27         if let Some(ref generic_args) = chain_method.args;
28         if let Some(GenericArg::Type(ref ty)) = generic_args.args.get(0);
29         then {
30             let ty = cx.typeck_results().node_type(ty.hir_id);
31             if is_type_diagnostic_item(cx, ty, sym::vec_type) ||
32                 is_type_diagnostic_item(cx, ty, sym::vecdeque_type) ||
33                 match_type(cx, ty, &paths::BTREEMAP) ||
34                 is_type_diagnostic_item(cx, ty, sym::hashmap_type) {
35                 if method.ident.name == sym!(len) {
36                     let span = shorten_needless_collect_span(expr);
37                     span_lint_and_sugg(
38                         cx,
39                         NEEDLESS_COLLECT,
40                         span,
41                         NEEDLESS_COLLECT_MSG,
42                         "replace with",
43                         "count()".to_string(),
44                         Applicability::MachineApplicable,
45                     );
46                 }
47                 if method.ident.name == sym!(is_empty) {
48                     let span = shorten_needless_collect_span(expr);
49                     span_lint_and_sugg(
50                         cx,
51                         NEEDLESS_COLLECT,
52                         span,
53                         NEEDLESS_COLLECT_MSG,
54                         "replace with",
55                         "next().is_none()".to_string(),
56                         Applicability::MachineApplicable,
57                     );
58                 }
59                 if method.ident.name == sym!(contains) {
60                     let contains_arg = snippet(cx, args[1].span, "??");
61                     let span = shorten_needless_collect_span(expr);
62                     span_lint_and_then(
63                         cx,
64                         NEEDLESS_COLLECT,
65                         span,
66                         NEEDLESS_COLLECT_MSG,
67                         |diag| {
68                             let (arg, pred) = contains_arg
69                                     .strip_prefix('&')
70                                     .map_or(("&x", &*contains_arg), |s| ("x", s));
71                             diag.span_suggestion(
72                                 span,
73                                 "replace with",
74                                 format!(
75                                     "any(|{}| x == {})",
76                                     arg, pred
77                                 ),
78                                 Applicability::MachineApplicable,
79                             );
80                         }
81                     );
82                 }
83             }
84         }
85     }
86 }
87
88 fn check_needless_collect_indirect_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
89     if let ExprKind::Block(ref block, _) = expr.kind {
90         for ref stmt in block.stmts {
91             if_chain! {
92                 if let StmtKind::Local(
93                     Local { pat: Pat { hir_id: pat_id, kind: PatKind::Binding(_, _, ident, .. ), .. },
94                     init: Some(ref init_expr), .. }
95                 ) = stmt.kind;
96                 if let ExprKind::MethodCall(ref method_name, _, &[ref iter_source], ..) = init_expr.kind;
97                 if method_name.ident.name == sym!(collect) && match_trait_method(cx, &init_expr, &paths::ITERATOR);
98                 if let Some(ref generic_args) = method_name.args;
99                 if let Some(GenericArg::Type(ref ty)) = generic_args.args.get(0);
100                 if let ty = cx.typeck_results().node_type(ty.hir_id);
101                 if is_type_diagnostic_item(cx, ty, sym::vec_type) ||
102                     is_type_diagnostic_item(cx, ty, sym::vecdeque_type) ||
103                     match_type(cx, ty, &paths::LINKED_LIST);
104                 if let Some(iter_calls) = detect_iter_and_into_iters(block, *ident);
105                 if iter_calls.len() == 1;
106                 then {
107                     let mut used_count_visitor = UsedCountVisitor {
108                         cx,
109                         id: *pat_id,
110                         count: 0,
111                     };
112                     walk_block(&mut used_count_visitor, block);
113                     if used_count_visitor.count > 1 {
114                         return;
115                     }
116
117                     // Suggest replacing iter_call with iter_replacement, and removing stmt
118                     let iter_call = &iter_calls[0];
119                     span_lint_and_then(
120                         cx,
121                         super::NEEDLESS_COLLECT,
122                         stmt.span.until(iter_call.span),
123                         NEEDLESS_COLLECT_MSG,
124                         |diag| {
125                             let iter_replacement = format!("{}{}", Sugg::hir(cx, iter_source, ".."), iter_call.get_iter_method(cx));
126                             diag.multipart_suggestion(
127                                 iter_call.get_suggestion_text(),
128                                 vec![
129                                     (stmt.span, String::new()),
130                                     (iter_call.span, iter_replacement)
131                                 ],
132                                 Applicability::MachineApplicable,// MaybeIncorrect,
133                             ).emit();
134                         },
135                     );
136                 }
137             }
138         }
139     }
140 }
141
142 struct IterFunction {
143     func: IterFunctionKind,
144     span: Span,
145 }
146 impl IterFunction {
147     fn get_iter_method(&self, cx: &LateContext<'_>) -> String {
148         match &self.func {
149             IterFunctionKind::IntoIter => String::new(),
150             IterFunctionKind::Len => String::from(".count()"),
151             IterFunctionKind::IsEmpty => String::from(".next().is_none()"),
152             IterFunctionKind::Contains(span) => {
153                 let s = snippet(cx, *span, "..");
154                 if let Some(stripped) = s.strip_prefix('&') {
155                     format!(".any(|x| x == {})", stripped)
156                 } else {
157                     format!(".any(|x| x == *{})", s)
158                 }
159             },
160         }
161     }
162     fn get_suggestion_text(&self) -> &'static str {
163         match &self.func {
164             IterFunctionKind::IntoIter => {
165                 "use the original Iterator instead of collecting it and then producing a new one"
166             },
167             IterFunctionKind::Len => {
168                 "take the original Iterator's count instead of collecting it and finding the length"
169             },
170             IterFunctionKind::IsEmpty => {
171                 "check if the original Iterator has anything instead of collecting it and seeing if it's empty"
172             },
173             IterFunctionKind::Contains(_) => {
174                 "check if the original Iterator contains an element instead of collecting then checking"
175             },
176         }
177     }
178 }
179 enum IterFunctionKind {
180     IntoIter,
181     Len,
182     IsEmpty,
183     Contains(Span),
184 }
185
186 struct IterFunctionVisitor {
187     uses: Vec<IterFunction>,
188     seen_other: bool,
189     target: Ident,
190 }
191 impl<'tcx> Visitor<'tcx> for IterFunctionVisitor {
192     fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
193         // Check function calls on our collection
194         if_chain! {
195             if let ExprKind::MethodCall(method_name, _, ref args, _) = &expr.kind;
196             if let Some(Expr { kind: ExprKind::Path(QPath::Resolved(_, ref path)), .. }) = args.get(0);
197             if let &[name] = &path.segments;
198             if name.ident == self.target;
199             then {
200                 let len = sym!(len);
201                 let is_empty = sym!(is_empty);
202                 let contains = sym!(contains);
203                 match method_name.ident.name {
204                     sym::into_iter => self.uses.push(
205                         IterFunction { func: IterFunctionKind::IntoIter, span: expr.span }
206                     ),
207                     name if name == len => self.uses.push(
208                         IterFunction { func: IterFunctionKind::Len, span: expr.span }
209                     ),
210                     name if name == is_empty => self.uses.push(
211                         IterFunction { func: IterFunctionKind::IsEmpty, span: expr.span }
212                     ),
213                     name if name == contains => self.uses.push(
214                         IterFunction { func: IterFunctionKind::Contains(args[1].span), span: expr.span }
215                     ),
216                     _ => self.seen_other = true,
217                 }
218                 return
219             }
220         }
221         // Check if the collection is used for anything else
222         if_chain! {
223             if let Expr { kind: ExprKind::Path(QPath::Resolved(_, ref path)), .. } = expr;
224             if let &[name] = &path.segments;
225             if name.ident == self.target;
226             then {
227                 self.seen_other = true;
228             } else {
229                 walk_expr(self, expr);
230             }
231         }
232     }
233
234     type Map = Map<'tcx>;
235     fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
236         NestedVisitorMap::None
237     }
238 }
239
240 struct UsedCountVisitor<'a, 'tcx> {
241     cx: &'a LateContext<'tcx>,
242     id: HirId,
243     count: usize,
244 }
245
246 impl<'a, 'tcx> Visitor<'tcx> for UsedCountVisitor<'a, 'tcx> {
247     type Map = Map<'tcx>;
248
249     fn visit_expr(&mut self, expr: &'tcx Expr<'_>) {
250         if path_to_local_id(expr, self.id) {
251             self.count += 1;
252         } else {
253             walk_expr(self, expr);
254         }
255     }
256
257     fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
258         NestedVisitorMap::OnlyBodies(self.cx.tcx.hir())
259     }
260 }
261
262 /// Detect the occurrences of calls to `iter` or `into_iter` for the
263 /// given identifier
264 fn detect_iter_and_into_iters<'tcx>(block: &'tcx Block<'tcx>, identifier: Ident) -> Option<Vec<IterFunction>> {
265     let mut visitor = IterFunctionVisitor {
266         uses: Vec::new(),
267         target: identifier,
268         seen_other: false,
269     };
270     visitor.visit_block(block);
271     if visitor.seen_other { None } else { Some(visitor.uses) }
272 }
273
274 fn shorten_needless_collect_span(expr: &Expr<'_>) -> Span {
275     if_chain! {
276         if let ExprKind::MethodCall(.., args, _) = &expr.kind;
277         if let ExprKind::MethodCall(_, span, ..) = &args[0].kind;
278         then {
279             return expr.span.with_lo(span.lo());
280         }
281     }
282     unreachable!();
283 }