]> git.lizzy.rs Git - rust.git/blob - clippy_lints/src/matches/match_same_arms.rs
adding spell checking
[rust.git] / clippy_lints / src / matches / match_same_arms.rs
1 use clippy_utils::diagnostics::span_lint_and_then;
2 use clippy_utils::source::snippet;
3 use clippy_utils::{path_to_local, search_same, SpanlessEq, SpanlessHash};
4 use core::cmp::Ordering;
5 use core::iter;
6 use core::slice;
7 use rustc_arena::DroplessArena;
8 use rustc_ast::ast::LitKind;
9 use rustc_errors::Applicability;
10 use rustc_hir::def_id::DefId;
11 use rustc_hir::{Arm, Expr, ExprKind, HirId, HirIdMap, HirIdSet, Pat, PatKind, RangeEnd};
12 use rustc_lint::LateContext;
13 use rustc_middle::ty;
14 use rustc_span::Symbol;
15 use std::collections::hash_map::Entry;
16
17 use super::MATCH_SAME_ARMS;
18
19 #[allow(clippy::too_many_lines)]
20 pub(super) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) {
21     let hash = |&(_, arm): &(usize, &Arm<'_>)| -> u64 {
22         let mut h = SpanlessHash::new(cx);
23         h.hash_expr(arm.body);
24         h.finish()
25     };
26
27     let arena = DroplessArena::default();
28     let normalized_pats: Vec<_> = arms
29         .iter()
30         .map(|a| NormalizedPat::from_pat(cx, &arena, a.pat))
31         .collect();
32
33     // The furthest forwards a pattern can move without semantic changes
34     let forwards_blocking_idxs: Vec<_> = normalized_pats
35         .iter()
36         .enumerate()
37         .map(|(i, pat)| {
38             normalized_pats[i + 1..]
39                 .iter()
40                 .enumerate()
41                 .find_map(|(j, other)| pat.has_overlapping_values(other).then(|| i + 1 + j))
42                 .unwrap_or(normalized_pats.len())
43         })
44         .collect();
45
46     // The furthest backwards a pattern can move without semantic changes
47     let backwards_blocking_idxs: Vec<_> = normalized_pats
48         .iter()
49         .enumerate()
50         .map(|(i, pat)| {
51             normalized_pats[..i]
52                 .iter()
53                 .enumerate()
54                 .rev()
55                 .zip(forwards_blocking_idxs[..i].iter().copied().rev())
56                 .skip_while(|&(_, forward_block)| forward_block > i)
57                 .find_map(|((j, other), forward_block)| {
58                     (forward_block == i || pat.has_overlapping_values(other)).then(|| j)
59                 })
60                 .unwrap_or(0)
61         })
62         .collect();
63
64     let eq = |&(lindex, lhs): &(usize, &Arm<'_>), &(rindex, rhs): &(usize, &Arm<'_>)| -> bool {
65         let min_index = usize::min(lindex, rindex);
66         let max_index = usize::max(lindex, rindex);
67
68         let mut local_map: HirIdMap<HirId> = HirIdMap::default();
69         let eq_fallback = |a: &Expr<'_>, b: &Expr<'_>| {
70             if_chain! {
71                 if let Some(a_id) = path_to_local(a);
72                 if let Some(b_id) = path_to_local(b);
73                 let entry = match local_map.entry(a_id) {
74                     Entry::Vacant(entry) => entry,
75                     // check if using the same bindings as before
76                     Entry::Occupied(entry) => return *entry.get() == b_id,
77                 };
78                 // the names technically don't have to match; this makes the lint more conservative
79                 if cx.tcx.hir().name(a_id) == cx.tcx.hir().name(b_id);
80                 if cx.typeck_results().expr_ty(a) == cx.typeck_results().expr_ty(b);
81                 if pat_contains_local(lhs.pat, a_id);
82                 if pat_contains_local(rhs.pat, b_id);
83                 then {
84                     entry.insert(b_id);
85                     true
86                 } else {
87                     false
88                 }
89             }
90         };
91         // Arms with a guard are ignored, those can’t always be merged together
92         // If both arms overlap with an arm in between then these can't be merged either.
93         !(backwards_blocking_idxs[max_index] > min_index && forwards_blocking_idxs[min_index] < max_index)
94                 && lhs.guard.is_none()
95                 && rhs.guard.is_none()
96                 && SpanlessEq::new(cx)
97                     .expr_fallback(eq_fallback)
98                     .eq_expr(lhs.body, rhs.body)
99                 // these checks could be removed to allow unused bindings
100                 && bindings_eq(lhs.pat, local_map.keys().copied().collect())
101                 && bindings_eq(rhs.pat, local_map.values().copied().collect())
102     };
103
104     let indexed_arms: Vec<(usize, &Arm<'_>)> = arms.iter().enumerate().collect();
105     for (&(i, arm1), &(j, arm2)) in search_same(&indexed_arms, hash, eq) {
106         if matches!(arm2.pat.kind, PatKind::Wild) {
107             span_lint_and_then(
108                 cx,
109                 MATCH_SAME_ARMS,
110                 arm1.span,
111                 "this match arm has an identical body to the `_` wildcard arm",
112                 |diag| {
113                     diag.span_suggestion(
114                         arm1.span,
115                         "try removing the arm",
116                         String::new(),
117                         Applicability::MaybeIncorrect,
118                     )
119                     .help("or try changing either arm body")
120                     .span_note(arm2.span, "`_` wildcard arm here");
121                 },
122             );
123         } else {
124             let back_block = backwards_blocking_idxs[j];
125             let (keep_arm, move_arm) = if back_block < i || (back_block == 0 && forwards_blocking_idxs[i] <= j) {
126                 (arm1, arm2)
127             } else {
128                 (arm2, arm1)
129             };
130
131             span_lint_and_then(
132                 cx,
133                 MATCH_SAME_ARMS,
134                 keep_arm.span,
135                 "this match arm has an identical body to another arm",
136                 |diag| {
137                     let move_pat_snip = snippet(cx, move_arm.pat.span, "<pat2>");
138                     let keep_pat_snip = snippet(cx, keep_arm.pat.span, "<pat1>");
139
140                     diag.span_suggestion(
141                         keep_arm.pat.span,
142                         "try merging the arm patterns",
143                         format!("{} | {}", keep_pat_snip, move_pat_snip),
144                         Applicability::MaybeIncorrect,
145                     )
146                     .help("or try changing either arm body")
147                     .span_note(move_arm.span, "other arm here");
148                 },
149             );
150         }
151     }
152 }
153
154 #[derive(Clone, Copy)]
155 enum NormalizedPat<'a> {
156     Wild,
157     Struct(Option<DefId>, &'a [(Symbol, Self)]),
158     Tuple(Option<DefId>, &'a [Self]),
159     Or(&'a [Self]),
160     Path(Option<DefId>),
161     LitStr(Symbol),
162     LitBytes(&'a [u8]),
163     LitInt(u128),
164     LitBool(bool),
165     Range(PatRange),
166     /// A slice pattern. If the second value is `None`, then this matches an exact size. Otherwise
167     /// the first value contains everything before the `..` wildcard pattern, and the second value
168     /// contains everything afterwards. Note that either side, or both sides, may contain zero
169     /// patterns.
170     Slice(&'a [Self], Option<&'a [Self]>),
171 }
172
173 #[derive(Clone, Copy)]
174 struct PatRange {
175     start: u128,
176     end: u128,
177     bounds: RangeEnd,
178 }
179 impl PatRange {
180     fn contains(&self, x: u128) -> bool {
181         x >= self.start
182             && match self.bounds {
183                 RangeEnd::Included => x <= self.end,
184                 RangeEnd::Excluded => x < self.end,
185             }
186     }
187
188     fn overlaps(&self, other: &Self) -> bool {
189         // Note: Empty ranges are impossible, so this is correct even though it would return true if an
190         // empty exclusive range were to reside within an inclusive range.
191         (match self.bounds {
192             RangeEnd::Included => self.end >= other.start,
193             RangeEnd::Excluded => self.end > other.start,
194         } && match other.bounds {
195             RangeEnd::Included => self.start <= other.end,
196             RangeEnd::Excluded => self.start < other.end,
197         })
198     }
199 }
200
201 /// Iterates over the pairs of fields with matching names.
202 fn iter_matching_struct_fields<'a>(
203     left: &'a [(Symbol, NormalizedPat<'a>)],
204     right: &'a [(Symbol, NormalizedPat<'a>)],
205 ) -> impl Iterator<Item = (&'a NormalizedPat<'a>, &'a NormalizedPat<'a>)> + 'a {
206     struct Iter<'a>(
207         slice::Iter<'a, (Symbol, NormalizedPat<'a>)>,
208         slice::Iter<'a, (Symbol, NormalizedPat<'a>)>,
209     );
210     impl<'a> Iterator for Iter<'a> {
211         type Item = (&'a NormalizedPat<'a>, &'a NormalizedPat<'a>);
212         fn next(&mut self) -> Option<Self::Item> {
213             // Note: all the fields in each slice are sorted by symbol value.
214             let mut left = self.0.next()?;
215             let mut right = self.1.next()?;
216             loop {
217                 match left.0.cmp(&right.0) {
218                     Ordering::Equal => return Some((&left.1, &right.1)),
219                     Ordering::Less => left = self.0.next()?,
220                     Ordering::Greater => right = self.1.next()?,
221                 }
222             }
223         }
224     }
225     Iter(left.iter(), right.iter())
226 }
227
228 #[allow(clippy::similar_names)]
229 impl<'a> NormalizedPat<'a> {
230     #[allow(clippy::too_many_lines)]
231     fn from_pat(cx: &LateContext<'_>, arena: &'a DroplessArena, pat: &'a Pat<'_>) -> Self {
232         match pat.kind {
233             PatKind::Wild | PatKind::Binding(.., None) => Self::Wild,
234             PatKind::Binding(.., Some(pat)) | PatKind::Box(pat) | PatKind::Ref(pat, _) => {
235                 Self::from_pat(cx, arena, pat)
236             },
237             PatKind::Struct(ref path, fields, _) => {
238                 let fields =
239                     arena.alloc_from_iter(fields.iter().map(|f| (f.ident.name, Self::from_pat(cx, arena, f.pat))));
240                 fields.sort_by_key(|&(name, _)| name);
241                 Self::Struct(cx.qpath_res(path, pat.hir_id).opt_def_id(), fields)
242             },
243             PatKind::TupleStruct(ref path, pats, wild_idx) => {
244                 let adt = match cx.typeck_results().pat_ty(pat).ty_adt_def() {
245                     Some(x) => x,
246                     None => return Self::Wild,
247                 };
248                 let (var_id, variant) = if adt.is_enum() {
249                     match cx.qpath_res(path, pat.hir_id).opt_def_id() {
250                         Some(x) => (Some(x), adt.variant_with_ctor_id(x)),
251                         None => return Self::Wild,
252                     }
253                 } else {
254                     (None, adt.non_enum_variant())
255                 };
256                 let (front, back) = match wild_idx {
257                     Some(i) => pats.split_at(i),
258                     None => (pats, [].as_slice()),
259                 };
260                 let pats = arena.alloc_from_iter(
261                     front
262                         .iter()
263                         .map(|pat| Self::from_pat(cx, arena, pat))
264                         .chain(iter::repeat_with(|| Self::Wild).take(variant.fields.len() - pats.len()))
265                         .chain(back.iter().map(|pat| Self::from_pat(cx, arena, pat))),
266                 );
267                 Self::Tuple(var_id, pats)
268             },
269             PatKind::Or(pats) => Self::Or(arena.alloc_from_iter(pats.iter().map(|pat| Self::from_pat(cx, arena, pat)))),
270             PatKind::Path(ref path) => Self::Path(cx.qpath_res(path, pat.hir_id).opt_def_id()),
271             PatKind::Tuple(pats, wild_idx) => {
272                 let field_count = match cx.typeck_results().pat_ty(pat).kind() {
273                     ty::Tuple(subs) => subs.len(),
274                     _ => return Self::Wild,
275                 };
276                 let (front, back) = match wild_idx {
277                     Some(i) => pats.split_at(i),
278                     None => (pats, [].as_slice()),
279                 };
280                 let pats = arena.alloc_from_iter(
281                     front
282                         .iter()
283                         .map(|pat| Self::from_pat(cx, arena, pat))
284                         .chain(iter::repeat_with(|| Self::Wild).take(field_count - pats.len()))
285                         .chain(back.iter().map(|pat| Self::from_pat(cx, arena, pat))),
286                 );
287                 Self::Tuple(None, pats)
288             },
289             PatKind::Lit(e) => match &e.kind {
290                 // TODO: Handle negative integers. They're currently treated as a wild match.
291                 ExprKind::Lit(lit) => match lit.node {
292                     LitKind::Str(sym, _) => Self::LitStr(sym),
293                     LitKind::ByteStr(ref bytes) => Self::LitBytes(&**bytes),
294                     LitKind::Byte(val) => Self::LitInt(val.into()),
295                     LitKind::Char(val) => Self::LitInt(val.into()),
296                     LitKind::Int(val, _) => Self::LitInt(val),
297                     LitKind::Bool(val) => Self::LitBool(val),
298                     LitKind::Float(..) | LitKind::Err(_) => Self::Wild,
299                 },
300                 _ => Self::Wild,
301             },
302             PatKind::Range(start, end, bounds) => {
303                 // TODO: Handle negative integers. They're currently treated as a wild match.
304                 let start = match start {
305                     None => 0,
306                     Some(e) => match &e.kind {
307                         ExprKind::Lit(lit) => match lit.node {
308                             LitKind::Int(val, _) => val,
309                             LitKind::Char(val) => val.into(),
310                             LitKind::Byte(val) => val.into(),
311                             _ => return Self::Wild,
312                         },
313                         _ => return Self::Wild,
314                     },
315                 };
316                 let (end, bounds) = match end {
317                     None => (u128::MAX, RangeEnd::Included),
318                     Some(e) => match &e.kind {
319                         ExprKind::Lit(lit) => match lit.node {
320                             LitKind::Int(val, _) => (val, bounds),
321                             LitKind::Char(val) => (val.into(), bounds),
322                             LitKind::Byte(val) => (val.into(), bounds),
323                             _ => return Self::Wild,
324                         },
325                         _ => return Self::Wild,
326                     },
327                 };
328                 Self::Range(PatRange { start, end, bounds })
329             },
330             PatKind::Slice(front, wild_pat, back) => Self::Slice(
331                 arena.alloc_from_iter(front.iter().map(|pat| Self::from_pat(cx, arena, pat))),
332                 wild_pat.map(|_| &*arena.alloc_from_iter(back.iter().map(|pat| Self::from_pat(cx, arena, pat)))),
333             ),
334         }
335     }
336
337     /// Checks if two patterns overlap in the values they can match assuming they are for the same
338     /// type.
339     fn has_overlapping_values(&self, other: &Self) -> bool {
340         match (*self, *other) {
341             (Self::Wild, _) | (_, Self::Wild) => true,
342             (Self::Or(pats), ref other) | (ref other, Self::Or(pats)) => {
343                 pats.iter().any(|pat| pat.has_overlapping_values(other))
344             },
345             (Self::Struct(lpath, lfields), Self::Struct(rpath, rfields)) => {
346                 if lpath != rpath {
347                     return false;
348                 }
349                 iter_matching_struct_fields(lfields, rfields).all(|(lpat, rpat)| lpat.has_overlapping_values(rpat))
350             },
351             (Self::Tuple(lpath, lpats), Self::Tuple(rpath, rpats)) => {
352                 if lpath != rpath {
353                     return false;
354                 }
355                 lpats
356                     .iter()
357                     .zip(rpats.iter())
358                     .all(|(lpat, rpat)| lpat.has_overlapping_values(rpat))
359             },
360             (Self::Path(x), Self::Path(y)) => x == y,
361             (Self::LitStr(x), Self::LitStr(y)) => x == y,
362             (Self::LitBytes(x), Self::LitBytes(y)) => x == y,
363             (Self::LitInt(x), Self::LitInt(y)) => x == y,
364             (Self::LitBool(x), Self::LitBool(y)) => x == y,
365             (Self::Range(ref x), Self::Range(ref y)) => x.overlaps(y),
366             (Self::Range(ref range), Self::LitInt(x)) | (Self::LitInt(x), Self::Range(ref range)) => range.contains(x),
367             (Self::Slice(lpats, None), Self::Slice(rpats, None)) => {
368                 lpats.len() == rpats.len() && lpats.iter().zip(rpats.iter()).all(|(x, y)| x.has_overlapping_values(y))
369             },
370             (Self::Slice(pats, None), Self::Slice(front, Some(back)))
371             | (Self::Slice(front, Some(back)), Self::Slice(pats, None)) => {
372                 // Here `pats` is an exact size match. If the combined lengths of `front` and `back` are greater
373                 // then the minium length required will be greater than the length of `pats`.
374                 if pats.len() < front.len() + back.len() {
375                     return false;
376                 }
377                 pats[..front.len()]
378                     .iter()
379                     .zip(front.iter())
380                     .chain(pats[pats.len() - back.len()..].iter().zip(back.iter()))
381                     .all(|(x, y)| x.has_overlapping_values(y))
382             },
383             (Self::Slice(lfront, Some(lback)), Self::Slice(rfront, Some(rback))) => lfront
384                 .iter()
385                 .zip(rfront.iter())
386                 .chain(lback.iter().rev().zip(rback.iter().rev()))
387                 .all(|(x, y)| x.has_overlapping_values(y)),
388
389             // Enums can mix unit variants with tuple/struct variants. These can never overlap.
390             (Self::Path(_), Self::Tuple(..) | Self::Struct(..))
391             | (Self::Tuple(..) | Self::Struct(..), Self::Path(_)) => false,
392
393             // Tuples can be matched like a struct.
394             (Self::Tuple(x, _), Self::Struct(y, _)) | (Self::Struct(x, _), Self::Tuple(y, _)) => {
395                 // TODO: check fields here.
396                 x == y
397             },
398
399             // TODO: Lit* with Path, Range with Path, LitBytes with Slice
400             _ => true,
401         }
402     }
403 }
404
405 fn pat_contains_local(pat: &Pat<'_>, id: HirId) -> bool {
406     let mut result = false;
407     pat.walk_short(|p| {
408         result |= matches!(p.kind, PatKind::Binding(_, binding_id, ..) if binding_id == id);
409         !result
410     });
411     result
412 }
413
414 /// Returns true if all the bindings in the `Pat` are in `ids` and vice versa
415 fn bindings_eq(pat: &Pat<'_>, mut ids: HirIdSet) -> bool {
416     let mut result = true;
417     pat.each_binding_or_first(&mut |_, id, _, _| result &= ids.remove(&id));
418     result && ids.is_empty()
419 }