]> git.lizzy.rs Git - rust.git/blob - clippy_lints/src/loops/needless_collect.rs
6a9aa08426c0621ddeaf067986ee012253779746
[rust.git] / clippy_lints / src / loops / needless_collect.rs
use super::NEEDLESS_COLLECT;
use clippy_utils::diagnostics::{span_lint_and_sugg, span_lint_and_then};
use clippy_utils::source::{snippet, snippet_with_applicability};
use clippy_utils::sugg::Sugg;
use clippy_utils::ty::is_type_diagnostic_item;
use clippy_utils::{is_trait_method, path_to_local_id};
use if_chain::if_chain;
use rustc_errors::Applicability;
use rustc_hir::intravisit::{walk_block, walk_expr, NestedVisitorMap, Visitor};
use rustc_hir::{Block, Expr, ExprKind, GenericArg, GenericArgs, HirId, Local, Pat, PatKind, QPath, StmtKind, Ty};
use rustc_lint::LateContext;
use rustc_middle::hir::map::Map;
use rustc_span::symbol::{sym, Ident};
use rustc_span::{MultiSpan, Span};
15
/// Primary message emitted by both the direct and the indirect form of the lint.
const NEEDLESS_COLLECT_MSG: &str = "avoid using `collect()` when not needed";

/// Runs both needless-`collect` checks on `expr`.
///
/// The direct check looks for `collect()` immediately followed by another method call
/// (`iter.collect::<Vec<_>>().len()`); the indirect check looks, inside a block, for a `let`
/// binding of a collected value that is then used only once.
pub(super) fn check<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
    check_needless_collect_direct_usage(expr, cx);
    check_needless_collect_indirect_usage(expr, cx);
}
22 fn check_needless_collect_direct_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
23     if_chain! {
24         if let ExprKind::MethodCall(method, _, args, _) = expr.kind;
25         if let ExprKind::MethodCall(chain_method, method0_span, _, _) = args[0].kind;
26         if chain_method.ident.name == sym!(collect) && is_trait_method(cx, &args[0], sym::Iterator);
27         if let Some(generic_args) = chain_method.args;
28         if let Some(GenericArg::Type(ref ty)) = generic_args.args.get(0);
29         if let Some(ty) = cx.typeck_results().node_type_opt(ty.hir_id);
30         then {
31             let is_empty_sugg = Some("next().is_none()".to_string());
32             let method_name = &*method.ident.name.as_str();
33             let sugg = if is_type_diagnostic_item(cx, ty, sym::vec_type) ||
34                         is_type_diagnostic_item(cx, ty, sym::vecdeque_type) {
35                 match method_name {
36                     "len" => Some("count()".to_string()),
37                     "is_empty" => is_empty_sugg,
38                     "contains" => {
39                         let contains_arg = snippet(cx, args[1].span, "??");
40                         let (arg, pred) = contains_arg
41                             .strip_prefix('&')
42                             .map_or(("&x", &*contains_arg), |s| ("x", s));
43                         Some(format!("any(|{}| x == {})", arg, pred))
44                     }
45                     _ => None,
46                 }
47             }
48             else if is_type_diagnostic_item(cx, ty, sym::BTreeMap) ||
49                 is_type_diagnostic_item(cx, ty, sym::hashmap_type) {
50                 match method_name {
51                     "is_empty" => is_empty_sugg,
52                     _ => None,
53                 }
54             }
55             else {
56                 None
57             };
58             if let Some(sugg) = sugg {
59                 span_lint_and_sugg(
60                     cx,
61                     NEEDLESS_COLLECT,
62                     method0_span.with_hi(expr.span.hi()),
63                     NEEDLESS_COLLECT_MSG,
64                     "replace with",
65                     sugg,
66                     Applicability::MachineApplicable,
67                 );
68             }
69         }
70     }
71 }
72
73 fn check_needless_collect_indirect_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateContext<'tcx>) {
74     fn get_hir_id<'tcx>(ty: Option<&Ty<'tcx>>, method_args: Option<&GenericArgs<'tcx>>) -> Option<HirId> {
75         if let Some(ty) = ty {
76             return Some(ty.hir_id);
77         }
78
79         if let Some(generic_args) = method_args {
80             if let Some(GenericArg::Type(ref ty)) = generic_args.args.get(0) {
81                 return Some(ty.hir_id);
82             }
83         }
84
85         None
86     }
87     if let ExprKind::Block(block, _) = expr.kind {
88         for stmt in block.stmts {
89             if_chain! {
90                 if let StmtKind::Local(
91                     Local { pat: Pat { hir_id: pat_id, kind: PatKind::Binding(_, _, ident, .. ), .. },
92                     init: Some(init_expr), ty, .. }
93                 ) = stmt.kind;
94                 if let ExprKind::MethodCall(method_name, collect_span, &[ref iter_source], ..) = init_expr.kind;
95                 if method_name.ident.name == sym!(collect) && is_trait_method(cx, init_expr, sym::Iterator);
96                 if let Some(hir_id) = get_hir_id(*ty, method_name.args);
97                 if let Some(ty) = cx.typeck_results().node_type_opt(hir_id);
98                 if is_type_diagnostic_item(cx, ty, sym::vec_type) ||
99                     is_type_diagnostic_item(cx, ty, sym::vecdeque_type) ||
100                     is_type_diagnostic_item(cx, ty, sym::BinaryHeap) ||
101                     is_type_diagnostic_item(cx, ty, sym::LinkedList);
102                 if let Some(iter_calls) = detect_iter_and_into_iters(block, *ident);
103                 if let [iter_call] = &*iter_calls;
104                 then {
105                     let mut used_count_visitor = UsedCountVisitor {
106                         cx,
107                         id: *pat_id,
108                         count: 0,
109                     };
110                     walk_block(&mut used_count_visitor, block);
111                     if used_count_visitor.count > 1 {
112                         return;
113                     }
114
115                     // Suggest replacing iter_call with iter_replacement, and removing stmt
116                     let mut span = MultiSpan::from_span(collect_span);
117                     span.push_span_label(iter_call.span, "the iterator could be used here instead".into());
118                     span_lint_and_then(
119                         cx,
120                         super::NEEDLESS_COLLECT,
121                         span,
122                         NEEDLESS_COLLECT_MSG,
123                         |diag| {
124                             let iter_replacement = format!("{}{}", Sugg::hir(cx, iter_source, ".."), iter_call.get_iter_method(cx));
125                             diag.multipart_suggestion(
126                                 iter_call.get_suggestion_text(),
127                                 vec![
128                                     (stmt.span, String::new()),
129                                     (iter_call.span, iter_replacement)
130                                 ],
131                                 Applicability::MachineApplicable,// MaybeIncorrect,
132                             );
133                         },
134                     );
135                 }
136             }
137         }
138     }
139 }
140
141 struct IterFunction {
142     func: IterFunctionKind,
143     span: Span,
144 }
145 impl IterFunction {
146     fn get_iter_method(&self, cx: &LateContext<'_>) -> String {
147         match &self.func {
148             IterFunctionKind::IntoIter => String::new(),
149             IterFunctionKind::Len => String::from(".count()"),
150             IterFunctionKind::IsEmpty => String::from(".next().is_none()"),
151             IterFunctionKind::Contains(span) => {
152                 let s = snippet(cx, *span, "..");
153                 if let Some(stripped) = s.strip_prefix('&') {
154                     format!(".any(|x| x == {})", stripped)
155                 } else {
156                     format!(".any(|x| x == *{})", s)
157                 }
158             },
159         }
160     }
161     fn get_suggestion_text(&self) -> &'static str {
162         match &self.func {
163             IterFunctionKind::IntoIter => {
164                 "use the original Iterator instead of collecting it and then producing a new one"
165             },
166             IterFunctionKind::Len => {
167                 "take the original Iterator's count instead of collecting it and finding the length"
168             },
169             IterFunctionKind::IsEmpty => {
170                 "check if the original Iterator has anything instead of collecting it and seeing if it's empty"
171             },
172             IterFunctionKind::Contains(_) => {
173                 "check if the original Iterator contains an element instead of collecting then checking"
174             },
175         }
176     }
177 }
178 enum IterFunctionKind {
179     IntoIter,
180     Len,
181     IsEmpty,
182     Contains(Span),
183 }
184
185 struct IterFunctionVisitor {
186     uses: Vec<IterFunction>,
187     seen_other: bool,
188     target: Ident,
189 }
190 impl<'tcx> Visitor<'tcx> for IterFunctionVisitor {
191     fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
192         // Check function calls on our collection
193         if_chain! {
194             if let ExprKind::MethodCall(method_name, _, args, _) = &expr.kind;
195             if let Some(Expr { kind: ExprKind::Path(QPath::Resolved(_, path)), .. }) = args.get(0);
196             if let &[name] = &path.segments;
197             if name.ident == self.target;
198             then {
199                 let len = sym!(len);
200                 let is_empty = sym!(is_empty);
201                 let contains = sym!(contains);
202                 match method_name.ident.name {
203                     sym::into_iter => self.uses.push(
204                         IterFunction { func: IterFunctionKind::IntoIter, span: expr.span }
205                     ),
206                     name if name == len => self.uses.push(
207                         IterFunction { func: IterFunctionKind::Len, span: expr.span }
208                     ),
209                     name if name == is_empty => self.uses.push(
210                         IterFunction { func: IterFunctionKind::IsEmpty, span: expr.span }
211                     ),
212                     name if name == contains => self.uses.push(
213                         IterFunction { func: IterFunctionKind::Contains(args[1].span), span: expr.span }
214                     ),
215                     _ => self.seen_other = true,
216                 }
217                 return
218             }
219         }
220         // Check if the collection is used for anything else
221         if_chain! {
222             if let Expr { kind: ExprKind::Path(QPath::Resolved(_, path)), .. } = expr;
223             if let &[name] = &path.segments;
224             if name.ident == self.target;
225             then {
226                 self.seen_other = true;
227             } else {
228                 walk_expr(self, expr);
229             }
230         }
231     }
232
233     type Map = Map<'tcx>;
234     fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
235         NestedVisitorMap::None
236     }
237 }
238
239 struct UsedCountVisitor<'a, 'tcx> {
240     cx: &'a LateContext<'tcx>,
241     id: HirId,
242     count: usize,
243 }
244
245 impl<'a, 'tcx> Visitor<'tcx> for UsedCountVisitor<'a, 'tcx> {
246     type Map = Map<'tcx>;
247
248     fn visit_expr(&mut self, expr: &'tcx Expr<'_>) {
249         if path_to_local_id(expr, self.id) {
250             self.count += 1;
251         } else {
252             walk_expr(self, expr);
253         }
254     }
255
256     fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
257         NestedVisitorMap::OnlyBodies(self.cx.tcx.hir())
258     }
259 }
260
261 /// Detect the occurrences of calls to `iter` or `into_iter` for the
262 /// given identifier
263 fn detect_iter_and_into_iters<'tcx>(block: &'tcx Block<'tcx>, identifier: Ident) -> Option<Vec<IterFunction>> {
264     let mut visitor = IterFunctionVisitor {
265         uses: Vec::new(),
266         target: identifier,
267         seen_other: false,
268     };
269     visitor.visit_block(block);
270     if visitor.seen_other { None } else { Some(visitor.uses) }
271 }