]> git.lizzy.rs Git - rust.git/blob - clippy_lints/src/loops/needless_collect.rs
Add check_needless_collect to its own module
[rust.git] / clippy_lints / src / loops / needless_collect.rs
1 use crate::utils::sugg::Sugg;
2 use crate::utils::{
3     is_type_diagnostic_item, match_trait_method, match_type, path_to_local_id, paths, snippet, span_lint_and_sugg,
4     span_lint_and_then,
5 };
6 use if_chain::if_chain;
7 use rustc_errors::Applicability;
8 use rustc_hir::intravisit::{walk_block, walk_expr, NestedVisitorMap, Visitor};
9 use rustc_hir::{Block, Expr, ExprKind, GenericArg, HirId, Local, Pat, PatKind, QPath, StmtKind};
10 use rustc_lint::LateContext;
11 use rustc_middle::hir::map::Map;
12 use rustc_span::source_map::Span;
13 use rustc_span::symbol::{sym, Ident};
14
15 const NEEDLESS_COLLECT_MSG: &str = "avoid using `collect()` when not needed";
16
17 pub(super) fn check_needless_collect<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
18     check_needless_collect_direct_usage(expr, cx);
19     check_needless_collect_indirect_usage(expr, cx);
20 }
21 fn check_needless_collect_direct_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
22     if_chain! {
23         if let ExprKind::MethodCall(ref method, _, ref args, _) = expr.kind;
24         if let ExprKind::MethodCall(ref chain_method, _, _, _) = args[0].kind;
25         if chain_method.ident.name == sym!(collect) && match_trait_method(cx, &args[0], &paths::ITERATOR);
26         if let Some(ref generic_args) = chain_method.args;
27         if let Some(GenericArg::Type(ref ty)) = generic_args.args.get(0);
28         then {
29             let ty = cx.typeck_results().node_type(ty.hir_id);
30             if is_type_diagnostic_item(cx, ty, sym::vec_type) ||
31                 is_type_diagnostic_item(cx, ty, sym!(vecdeque_type)) ||
32                 match_type(cx, ty, &paths::BTREEMAP) ||
33                 is_type_diagnostic_item(cx, ty, sym!(hashmap_type)) {
34                 if method.ident.name == sym!(len) {
35                     let span = shorten_needless_collect_span(expr);
36                     span_lint_and_sugg(
37                         cx,
38                         super::NEEDLESS_COLLECT,
39                         span,
40                         NEEDLESS_COLLECT_MSG,
41                         "replace with",
42                         "count()".to_string(),
43                         Applicability::MachineApplicable,
44                     );
45                 }
46                 if method.ident.name == sym!(is_empty) {
47                     let span = shorten_needless_collect_span(expr);
48                     span_lint_and_sugg(
49                         cx,
50                         super::NEEDLESS_COLLECT,
51                         span,
52                         NEEDLESS_COLLECT_MSG,
53                         "replace with",
54                         "next().is_none()".to_string(),
55                         Applicability::MachineApplicable,
56                     );
57                 }
58                 if method.ident.name == sym!(contains) {
59                     let contains_arg = snippet(cx, args[1].span, "??");
60                     let span = shorten_needless_collect_span(expr);
61                     span_lint_and_then(
62                         cx,
63                         super::NEEDLESS_COLLECT,
64                         span,
65                         NEEDLESS_COLLECT_MSG,
66                         |diag| {
67                             let (arg, pred) = contains_arg
68                                     .strip_prefix('&')
69                                     .map_or(("&x", &*contains_arg), |s| ("x", s));
70                             diag.span_suggestion(
71                                 span,
72                                 "replace with",
73                                 format!(
74                                     "any(|{}| x == {})",
75                                     arg, pred
76                                 ),
77                                 Applicability::MachineApplicable,
78                             );
79                         }
80                     );
81                 }
82             }
83         }
84     }
85 }
86
87 fn check_needless_collect_indirect_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
88     if let ExprKind::Block(ref block, _) = expr.kind {
89         for ref stmt in block.stmts {
90             if_chain! {
91                 if let StmtKind::Local(
92                     Local { pat: Pat { hir_id: pat_id, kind: PatKind::Binding(_, _, ident, .. ), .. },
93                     init: Some(ref init_expr), .. }
94                 ) = stmt.kind;
95                 if let ExprKind::MethodCall(ref method_name, _, &[ref iter_source], ..) = init_expr.kind;
96                 if method_name.ident.name == sym!(collect) && match_trait_method(cx, &init_expr, &paths::ITERATOR);
97                 if let Some(ref generic_args) = method_name.args;
98                 if let Some(GenericArg::Type(ref ty)) = generic_args.args.get(0);
99                 if let ty = cx.typeck_results().node_type(ty.hir_id);
100                 if is_type_diagnostic_item(cx, ty, sym::vec_type) ||
101                     is_type_diagnostic_item(cx, ty, sym!(vecdeque_type)) ||
102                     match_type(cx, ty, &paths::LINKED_LIST);
103                 if let Some(iter_calls) = detect_iter_and_into_iters(block, *ident);
104                 if iter_calls.len() == 1;
105                 then {
106                     let mut used_count_visitor = UsedCountVisitor {
107                         cx,
108                         id: *pat_id,
109                         count: 0,
110                     };
111                     walk_block(&mut used_count_visitor, block);
112                     if used_count_visitor.count > 1 {
113                         return;
114                     }
115
116                     // Suggest replacing iter_call with iter_replacement, and removing stmt
117                     let iter_call = &iter_calls[0];
118                     span_lint_and_then(
119                         cx,
120                         super::NEEDLESS_COLLECT,
121                         stmt.span.until(iter_call.span),
122                         NEEDLESS_COLLECT_MSG,
123                         |diag| {
124                             let iter_replacement = format!("{}{}", Sugg::hir(cx, iter_source, ".."), iter_call.get_iter_method(cx));
125                             diag.multipart_suggestion(
126                                 iter_call.get_suggestion_text(),
127                                 vec![
128                                     (stmt.span, String::new()),
129                                     (iter_call.span, iter_replacement)
130                                 ],
131                                 Applicability::MachineApplicable,// MaybeIncorrect,
132                             ).emit();
133                         },
134                     );
135                 }
136             }
137         }
138     }
139 }
140
141 struct IterFunction {
142     func: IterFunctionKind,
143     span: Span,
144 }
145 impl IterFunction {
146     fn get_iter_method(&self, cx: &LateContext<'_>) -> String {
147         match &self.func {
148             IterFunctionKind::IntoIter => String::new(),
149             IterFunctionKind::Len => String::from(".count()"),
150             IterFunctionKind::IsEmpty => String::from(".next().is_none()"),
151             IterFunctionKind::Contains(span) => {
152                 let s = snippet(cx, *span, "..");
153                 if let Some(stripped) = s.strip_prefix('&') {
154                     format!(".any(|x| x == {})", stripped)
155                 } else {
156                     format!(".any(|x| x == *{})", s)
157                 }
158             },
159         }
160     }
161     fn get_suggestion_text(&self) -> &'static str {
162         match &self.func {
163             IterFunctionKind::IntoIter => {
164                 "use the original Iterator instead of collecting it and then producing a new one"
165             },
166             IterFunctionKind::Len => {
167                 "take the original Iterator's count instead of collecting it and finding the length"
168             },
169             IterFunctionKind::IsEmpty => {
170                 "check if the original Iterator has anything instead of collecting it and seeing if it's empty"
171             },
172             IterFunctionKind::Contains(_) => {
173                 "check if the original Iterator contains an element instead of collecting then checking"
174             },
175         }
176     }
177 }
178 enum IterFunctionKind {
179     IntoIter,
180     Len,
181     IsEmpty,
182     Contains(Span),
183 }
184
185 struct IterFunctionVisitor {
186     uses: Vec<IterFunction>,
187     seen_other: bool,
188     target: Ident,
189 }
190 impl<'tcx> Visitor<'tcx> for IterFunctionVisitor {
191     fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
192         // Check function calls on our collection
193         if_chain! {
194             if let ExprKind::MethodCall(method_name, _, ref args, _) = &expr.kind;
195             if let Some(Expr { kind: ExprKind::Path(QPath::Resolved(_, ref path)), .. }) = args.get(0);
196             if let &[name] = &path.segments;
197             if name.ident == self.target;
198             then {
199                 let len = sym!(len);
200                 let is_empty = sym!(is_empty);
201                 let contains = sym!(contains);
202                 match method_name.ident.name {
203                     sym::into_iter => self.uses.push(
204                         IterFunction { func: IterFunctionKind::IntoIter, span: expr.span }
205                     ),
206                     name if name == len => self.uses.push(
207                         IterFunction { func: IterFunctionKind::Len, span: expr.span }
208                     ),
209                     name if name == is_empty => self.uses.push(
210                         IterFunction { func: IterFunctionKind::IsEmpty, span: expr.span }
211                     ),
212                     name if name == contains => self.uses.push(
213                         IterFunction { func: IterFunctionKind::Contains(args[1].span), span: expr.span }
214                     ),
215                     _ => self.seen_other = true,
216                 }
217                 return
218             }
219         }
220         // Check if the collection is used for anything else
221         if_chain! {
222             if let Expr { kind: ExprKind::Path(QPath::Resolved(_, ref path)), .. } = expr;
223             if let &[name] = &path.segments;
224             if name.ident == self.target;
225             then {
226                 self.seen_other = true;
227             } else {
228                 walk_expr(self, expr);
229             }
230         }
231     }
232
233     type Map = Map<'tcx>;
234     fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
235         NestedVisitorMap::None
236     }
237 }
238
239 struct UsedCountVisitor<'a, 'tcx> {
240     cx: &'a LateContext<'tcx>,
241     id: HirId,
242     count: usize,
243 }
244
245 impl<'a, 'tcx> Visitor<'tcx> for UsedCountVisitor<'a, 'tcx> {
246     type Map = Map<'tcx>;
247
248     fn visit_expr(&mut self, expr: &'tcx Expr<'_>) {
249         if path_to_local_id(expr, self.id) {
250             self.count += 1;
251         } else {
252             walk_expr(self, expr);
253         }
254     }
255
256     fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
257         NestedVisitorMap::OnlyBodies(self.cx.tcx.hir())
258     }
259 }
260
261 /// Detect the occurrences of calls to `iter` or `into_iter` for the
262 /// given identifier
263 fn detect_iter_and_into_iters<'tcx>(block: &'tcx Block<'tcx>, identifier: Ident) -> Option<Vec<IterFunction>> {
264     let mut visitor = IterFunctionVisitor {
265         uses: Vec::new(),
266         target: identifier,
267         seen_other: false,
268     };
269     visitor.visit_block(block);
270     if visitor.seen_other { None } else { Some(visitor.uses) }
271 }
272
273 fn shorten_needless_collect_span(expr: &Expr<'_>) -> Span {
274     if_chain! {
275         if let ExprKind::MethodCall(.., args, _) = &expr.kind;
276         if let ExprKind::MethodCall(_, span, ..) = &args[0].kind;
277         then {
278             return expr.span.with_lo(span.lo());
279         }
280     }
281     unreachable!();
282 }