]>
Commit | Line | Data |
---|---|---|
5099ac24 FG |
1 | use clippy_utils::diagnostics::span_lint_and_then; |
2 | use clippy_utils::source::snippet; | |
3 | use clippy_utils::{path_to_local, search_same, SpanlessEq, SpanlessHash}; | |
5e7ed085 FG |
4 | use core::cmp::Ordering; |
5 | use core::iter; | |
6 | use core::slice; | |
7 | use rustc_arena::DroplessArena; | |
8 | use rustc_ast::ast::LitKind; | |
9 | use rustc_errors::Applicability; | |
10 | use rustc_hir::def_id::DefId; | |
11 | use rustc_hir::{Arm, Expr, ExprKind, HirId, HirIdMap, HirIdSet, Pat, PatKind, RangeEnd}; | |
5099ac24 | 12 | use rustc_lint::LateContext; |
5e7ed085 FG |
13 | use rustc_middle::ty; |
14 | use rustc_span::Symbol; | |
5099ac24 FG |
15 | use std::collections::hash_map::Entry; |
16 | ||
17 | use super::MATCH_SAME_ARMS; | |
18 | ||
923072b8 | 19 | #[expect(clippy::too_many_lines)] |
5e7ed085 FG |
20 | pub(super) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) { |
21 | let hash = |&(_, arm): &(usize, &Arm<'_>)| -> u64 { | |
22 | let mut h = SpanlessHash::new(cx); | |
23 | h.hash_expr(arm.body); | |
24 | h.finish() | |
25 | }; | |
5099ac24 | 26 | |
5e7ed085 FG |
27 | let arena = DroplessArena::default(); |
28 | let normalized_pats: Vec<_> = arms | |
29 | .iter() | |
30 | .map(|a| NormalizedPat::from_pat(cx, &arena, a.pat)) | |
31 | .collect(); | |
32 | ||
04454e1e | 33 | // The furthest forwards a pattern can move without semantic changes |
5e7ed085 FG |
34 | let forwards_blocking_idxs: Vec<_> = normalized_pats |
35 | .iter() | |
36 | .enumerate() | |
37 | .map(|(i, pat)| { | |
38 | normalized_pats[i + 1..] | |
39 | .iter() | |
40 | .enumerate() | |
41 | .find_map(|(j, other)| pat.has_overlapping_values(other).then(|| i + 1 + j)) | |
42 | .unwrap_or(normalized_pats.len()) | |
43 | }) | |
44 | .collect(); | |
45 | ||
04454e1e | 46 | // The furthest backwards a pattern can move without semantic changes |
5e7ed085 FG |
47 | let backwards_blocking_idxs: Vec<_> = normalized_pats |
48 | .iter() | |
49 | .enumerate() | |
50 | .map(|(i, pat)| { | |
51 | normalized_pats[..i] | |
52 | .iter() | |
53 | .enumerate() | |
54 | .rev() | |
55 | .zip(forwards_blocking_idxs[..i].iter().copied().rev()) | |
56 | .skip_while(|&(_, forward_block)| forward_block > i) | |
57 | .find_map(|((j, other), forward_block)| { | |
58 | (forward_block == i || pat.has_overlapping_values(other)).then(|| j) | |
59 | }) | |
60 | .unwrap_or(0) | |
61 | }) | |
62 | .collect(); | |
63 | ||
64 | let eq = |&(lindex, lhs): &(usize, &Arm<'_>), &(rindex, rhs): &(usize, &Arm<'_>)| -> bool { | |
65 | let min_index = usize::min(lindex, rindex); | |
66 | let max_index = usize::max(lindex, rindex); | |
67 | ||
68 | let mut local_map: HirIdMap<HirId> = HirIdMap::default(); | |
69 | let eq_fallback = |a: &Expr<'_>, b: &Expr<'_>| { | |
70 | if_chain! { | |
71 | if let Some(a_id) = path_to_local(a); | |
72 | if let Some(b_id) = path_to_local(b); | |
73 | let entry = match local_map.entry(a_id) { | |
74 | Entry::Vacant(entry) => entry, | |
75 | // check if using the same bindings as before | |
76 | Entry::Occupied(entry) => return *entry.get() == b_id, | |
77 | }; | |
78 | // the names technically don't have to match; this makes the lint more conservative | |
79 | if cx.tcx.hir().name(a_id) == cx.tcx.hir().name(b_id); | |
80 | if cx.typeck_results().expr_ty(a) == cx.typeck_results().expr_ty(b); | |
81 | if pat_contains_local(lhs.pat, a_id); | |
82 | if pat_contains_local(rhs.pat, b_id); | |
83 | then { | |
84 | entry.insert(b_id); | |
85 | true | |
86 | } else { | |
87 | false | |
5099ac24 | 88 | } |
5e7ed085 FG |
89 | } |
90 | }; | |
91 | // Arms with a guard are ignored, those can’t always be merged together | |
92 | // If both arms overlap with an arm in between then these can't be merged either. | |
93 | !(backwards_blocking_idxs[max_index] > min_index && forwards_blocking_idxs[min_index] < max_index) | |
94 | && lhs.guard.is_none() | |
95 | && rhs.guard.is_none() | |
5099ac24 FG |
96 | && SpanlessEq::new(cx) |
97 | .expr_fallback(eq_fallback) | |
98 | .eq_expr(lhs.body, rhs.body) | |
99 | // these checks could be removed to allow unused bindings | |
100 | && bindings_eq(lhs.pat, local_map.keys().copied().collect()) | |
101 | && bindings_eq(rhs.pat, local_map.values().copied().collect()) | |
5e7ed085 | 102 | }; |
5099ac24 | 103 | |
5e7ed085 FG |
104 | let indexed_arms: Vec<(usize, &Arm<'_>)> = arms.iter().enumerate().collect(); |
105 | for (&(i, arm1), &(j, arm2)) in search_same(&indexed_arms, hash, eq) { | |
106 | if matches!(arm2.pat.kind, PatKind::Wild) { | |
5099ac24 FG |
107 | span_lint_and_then( |
108 | cx, | |
109 | MATCH_SAME_ARMS, | |
5e7ed085 FG |
110 | arm1.span, |
111 | "this match arm has an identical body to the `_` wildcard arm", | |
5099ac24 | 112 | |diag| { |
923072b8 FG |
113 | diag.span_suggestion(arm1.span, "try removing the arm", "", Applicability::MaybeIncorrect) |
114 | .help("or try changing either arm body") | |
115 | .span_note(arm2.span, "`_` wildcard arm here"); | |
5e7ed085 FG |
116 | }, |
117 | ); | |
118 | } else { | |
119 | let back_block = backwards_blocking_idxs[j]; | |
120 | let (keep_arm, move_arm) = if back_block < i || (back_block == 0 && forwards_blocking_idxs[i] <= j) { | |
121 | (arm1, arm2) | |
122 | } else { | |
123 | (arm2, arm1) | |
124 | }; | |
125 | ||
126 | span_lint_and_then( | |
127 | cx, | |
128 | MATCH_SAME_ARMS, | |
129 | keep_arm.span, | |
130 | "this match arm has an identical body to another arm", | |
131 | |diag| { | |
132 | let move_pat_snip = snippet(cx, move_arm.pat.span, "<pat2>"); | |
133 | let keep_pat_snip = snippet(cx, keep_arm.pat.span, "<pat1>"); | |
134 | ||
135 | diag.span_suggestion( | |
136 | keep_arm.pat.span, | |
137 | "try merging the arm patterns", | |
138 | format!("{} | {}", keep_pat_snip, move_pat_snip), | |
139 | Applicability::MaybeIncorrect, | |
140 | ) | |
141 | .help("or try changing either arm body") | |
142 | .span_note(move_arm.span, "other arm here"); | |
5099ac24 FG |
143 | }, |
144 | ); | |
145 | } | |
146 | } | |
147 | } | |
148 | ||
5e7ed085 FG |
/// A simplified, `Copy` form of a match pattern, used to decide whether two patterns can match
/// the same value. Nested patterns are stored as slices borrowed for `'a`.
#[derive(Clone, Copy)]
enum NormalizedPat<'a> {
    /// Treated as overlapping with every other pattern. Also used for anything `from_pat` does
    /// not model, e.g. a binding without a sub-pattern or a float literal.
    Wild,
    /// A struct pattern: the resolved definition of its path (if any) and its fields, sorted by
    /// field name.
    Struct(Option<DefId>, &'a [(Symbol, Self)]),
    /// A tuple or tuple-struct pattern with a `..` rest pattern expanded into wildcards. The id
    /// is the resolved path for enum variants, and `None` for plain tuples and non-enum tuple
    /// structs.
    Tuple(Option<DefId>, &'a [Self]),
    /// An or-pattern (`a | b`), holding each alternative.
    Or(&'a [Self]),
    /// A path pattern, identified by the resolved definition of the path (if any).
    Path(Option<DefId>),
    /// A string literal.
    LitStr(Symbol),
    /// A byte-string literal.
    LitBytes(&'a [u8]),
    /// An integer, byte or char literal, widened to `u128`.
    LitInt(u128),
    /// A boolean literal.
    LitBool(bool),
    /// A range pattern whose bounds are integer, byte or char literals.
    Range(PatRange),
    /// A slice pattern. If the second value is `None`, then this matches an exact size. Otherwise
    /// the first value contains everything before the `..` wildcard pattern, and the second value
    /// contains everything afterwards. Note that either side, or both sides, may contain zero
    /// patterns.
    Slice(&'a [Self], Option<&'a [Self]>),
}
167 | ||
168 | #[derive(Clone, Copy)] | |
169 | struct PatRange { | |
170 | start: u128, | |
171 | end: u128, | |
172 | bounds: RangeEnd, | |
173 | } | |
174 | impl PatRange { | |
175 | fn contains(&self, x: u128) -> bool { | |
176 | x >= self.start | |
177 | && match self.bounds { | |
178 | RangeEnd::Included => x <= self.end, | |
179 | RangeEnd::Excluded => x < self.end, | |
180 | } | |
181 | } | |
182 | ||
183 | fn overlaps(&self, other: &Self) -> bool { | |
184 | // Note: Empty ranges are impossible, so this is correct even though it would return true if an | |
185 | // empty exclusive range were to reside within an inclusive range. | |
186 | (match self.bounds { | |
187 | RangeEnd::Included => self.end >= other.start, | |
188 | RangeEnd::Excluded => self.end > other.start, | |
189 | } && match other.bounds { | |
190 | RangeEnd::Included => self.start <= other.end, | |
191 | RangeEnd::Excluded => self.start < other.end, | |
192 | }) | |
193 | } | |
194 | } | |
195 | ||
196 | /// Iterates over the pairs of fields with matching names. | |
197 | fn iter_matching_struct_fields<'a>( | |
198 | left: &'a [(Symbol, NormalizedPat<'a>)], | |
199 | right: &'a [(Symbol, NormalizedPat<'a>)], | |
200 | ) -> impl Iterator<Item = (&'a NormalizedPat<'a>, &'a NormalizedPat<'a>)> + 'a { | |
201 | struct Iter<'a>( | |
202 | slice::Iter<'a, (Symbol, NormalizedPat<'a>)>, | |
203 | slice::Iter<'a, (Symbol, NormalizedPat<'a>)>, | |
204 | ); | |
205 | impl<'a> Iterator for Iter<'a> { | |
206 | type Item = (&'a NormalizedPat<'a>, &'a NormalizedPat<'a>); | |
207 | fn next(&mut self) -> Option<Self::Item> { | |
208 | // Note: all the fields in each slice are sorted by symbol value. | |
209 | let mut left = self.0.next()?; | |
210 | let mut right = self.1.next()?; | |
211 | loop { | |
212 | match left.0.cmp(&right.0) { | |
213 | Ordering::Equal => return Some((&left.1, &right.1)), | |
214 | Ordering::Less => left = self.0.next()?, | |
215 | Ordering::Greater => right = self.1.next()?, | |
216 | } | |
217 | } | |
218 | } | |
219 | } | |
220 | Iter(left.iter(), right.iter()) | |
221 | } | |
222 | ||
#[expect(clippy::similar_names)]
impl<'a> NormalizedPat<'a> {
    /// Converts a HIR pattern into its normalized form, allocating nested patterns in `arena`.
    ///
    /// Anything that is not modelled precisely (a binding without a sub-pattern, a float literal,
    /// a non-literal expression, a pattern whose type cannot be resolved as expected) becomes
    /// `Self::Wild`.
    #[expect(clippy::too_many_lines)]
    fn from_pat(cx: &LateContext<'_>, arena: &'a DroplessArena, pat: &'a Pat<'_>) -> Self {
        match pat.kind {
            PatKind::Wild | PatKind::Binding(.., None) => Self::Wild,
            // Bindings with a sub-pattern, box patterns and reference patterns are looked through.
            PatKind::Binding(.., Some(pat)) | PatKind::Box(pat) | PatKind::Ref(pat, _) => {
                Self::from_pat(cx, arena, pat)
            },
            PatKind::Struct(ref path, fields, _) => {
                let fields =
                    arena.alloc_from_iter(fields.iter().map(|f| (f.ident.name, Self::from_pat(cx, arena, f.pat))));
                // Sorted by name so that `iter_matching_struct_fields` can walk two field lists in
                // lockstep.
                fields.sort_by_key(|&(name, _)| name);
                Self::Struct(cx.qpath_res(path, pat.hir_id).opt_def_id(), fields)
            },
            PatKind::TupleStruct(ref path, pats, wild_idx) => {
                let adt = match cx.typeck_results().pat_ty(pat).ty_adt_def() {
                    Some(x) => x,
                    None => return Self::Wild,
                };
                // Enum variants are identified by the resolved path; other tuple structs have a
                // single variant and carry no id.
                let (var_id, variant) = if adt.is_enum() {
                    match cx.qpath_res(path, pat.hir_id).opt_def_id() {
                        Some(x) => (Some(x), adt.variant_with_ctor_id(x)),
                        None => return Self::Wild,
                    }
                } else {
                    (None, adt.non_enum_variant())
                };
                let (front, back) = match wild_idx {
                    Some(i) => pats.split_at(i),
                    None => (pats, [].as_slice()),
                };
                // Fill the gap left by a `..` rest pattern with wildcards so that there is one
                // entry per field.
                let pats = arena.alloc_from_iter(
                    front
                        .iter()
                        .map(|pat| Self::from_pat(cx, arena, pat))
                        .chain(iter::repeat_with(|| Self::Wild).take(variant.fields.len() - pats.len()))
                        .chain(back.iter().map(|pat| Self::from_pat(cx, arena, pat))),
                );
                Self::Tuple(var_id, pats)
            },
            PatKind::Or(pats) => Self::Or(arena.alloc_from_iter(pats.iter().map(|pat| Self::from_pat(cx, arena, pat)))),
            PatKind::Path(ref path) => Self::Path(cx.qpath_res(path, pat.hir_id).opt_def_id()),
            PatKind::Tuple(pats, wild_idx) => {
                let field_count = match cx.typeck_results().pat_ty(pat).kind() {
                    ty::Tuple(subs) => subs.len(),
                    _ => return Self::Wild,
                };
                let (front, back) = match wild_idx {
                    Some(i) => pats.split_at(i),
                    None => (pats, [].as_slice()),
                };
                // Same `..` expansion as for tuple structs above.
                let pats = arena.alloc_from_iter(
                    front
                        .iter()
                        .map(|pat| Self::from_pat(cx, arena, pat))
                        .chain(iter::repeat_with(|| Self::Wild).take(field_count - pats.len()))
                        .chain(back.iter().map(|pat| Self::from_pat(cx, arena, pat))),
                );
                Self::Tuple(None, pats)
            },
            PatKind::Lit(e) => match &e.kind {
                // TODO: Handle negative integers. They're currently treated as a wild match.
                ExprKind::Lit(lit) => match lit.node {
                    LitKind::Str(sym, _) => Self::LitStr(sym),
                    LitKind::ByteStr(ref bytes) => Self::LitBytes(&**bytes),
                    LitKind::Byte(val) => Self::LitInt(val.into()),
                    LitKind::Char(val) => Self::LitInt(val.into()),
                    LitKind::Int(val, _) => Self::LitInt(val),
                    LitKind::Bool(val) => Self::LitBool(val),
                    LitKind::Float(..) | LitKind::Err(_) => Self::Wild,
                },
                _ => Self::Wild,
            },
            PatKind::Range(start, end, bounds) => {
                // TODO: Handle negative integers. They're currently treated as a wild match.
                // A missing start bound is treated as `0`.
                let start = match start {
                    None => 0,
                    Some(e) => match &e.kind {
                        ExprKind::Lit(lit) => match lit.node {
                            LitKind::Int(val, _) => val,
                            LitKind::Char(val) => val.into(),
                            LitKind::Byte(val) => val.into(),
                            _ => return Self::Wild,
                        },
                        _ => return Self::Wild,
                    },
                };
                // A missing end bound is treated as an inclusive `u128::MAX`.
                let (end, bounds) = match end {
                    None => (u128::MAX, RangeEnd::Included),
                    Some(e) => match &e.kind {
                        ExprKind::Lit(lit) => match lit.node {
                            LitKind::Int(val, _) => (val, bounds),
                            LitKind::Char(val) => (val.into(), bounds),
                            LitKind::Byte(val) => (val.into(), bounds),
                            _ => return Self::Wild,
                        },
                        _ => return Self::Wild,
                    },
                };
                Self::Range(PatRange { start, end, bounds })
            },
            // The second slice is only present when the pattern contains a `..` wildcard.
            PatKind::Slice(front, wild_pat, back) => Self::Slice(
                arena.alloc_from_iter(front.iter().map(|pat| Self::from_pat(cx, arena, pat))),
                wild_pat.map(|_| &*arena.alloc_from_iter(back.iter().map(|pat| Self::from_pat(cx, arena, pat)))),
            ),
        }
    }

    /// Checks if two patterns overlap in the values they can match assuming they are for the same
    /// type.
    ///
    /// Combinations that are not handled explicitly are reported as overlapping.
    fn has_overlapping_values(&self, other: &Self) -> bool {
        match (*self, *other) {
            (Self::Wild, _) | (_, Self::Wild) => true,
            // An or-pattern overlaps if any of its alternatives does.
            (Self::Or(pats), ref other) | (ref other, Self::Or(pats)) => {
                pats.iter().any(|pat| pat.has_overlapping_values(other))
            },
            (Self::Struct(lpath, lfields), Self::Struct(rpath, rfields)) => {
                if lpath != rpath {
                    return false;
                }
                // Only fields named on both sides are compared.
                iter_matching_struct_fields(lfields, rfields).all(|(lpat, rpat)| lpat.has_overlapping_values(rpat))
            },
            (Self::Tuple(lpath, lpats), Self::Tuple(rpath, rpats)) => {
                if lpath != rpath {
                    return false;
                }
                lpats
                    .iter()
                    .zip(rpats.iter())
                    .all(|(lpat, rpat)| lpat.has_overlapping_values(rpat))
            },
            (Self::Path(x), Self::Path(y)) => x == y,
            (Self::LitStr(x), Self::LitStr(y)) => x == y,
            (Self::LitBytes(x), Self::LitBytes(y)) => x == y,
            (Self::LitInt(x), Self::LitInt(y)) => x == y,
            (Self::LitBool(x), Self::LitBool(y)) => x == y,
            (Self::Range(ref x), Self::Range(ref y)) => x.overlaps(y),
            (Self::Range(ref range), Self::LitInt(x)) | (Self::LitInt(x), Self::Range(ref range)) => range.contains(x),
            // Two exact-size slice patterns: the lengths must agree and every element must overlap.
            (Self::Slice(lpats, None), Self::Slice(rpats, None)) => {
                lpats.len() == rpats.len() && lpats.iter().zip(rpats.iter()).all(|(x, y)| x.has_overlapping_values(y))
            },
            (Self::Slice(pats, None), Self::Slice(front, Some(back)))
            | (Self::Slice(front, Some(back)), Self::Slice(pats, None)) => {
                // Here `pats` is an exact size match. If the combined lengths of `front` and `back` are
                // greater than the length of `pats`, then the minimum length required by the other
                // pattern can never be reached, so the two cannot overlap.
                if pats.len() < front.len() + back.len() {
                    return false;
                }
                // Compare the leading elements against `front` and the trailing elements against `back`.
                pats[..front.len()]
                    .iter()
                    .zip(front.iter())
                    .chain(pats[pats.len() - back.len()..].iter().zip(back.iter()))
                    .all(|(x, y)| x.has_overlapping_values(y))
            },
            // Both patterns contain `..`: pair up the fronts from the start and the backs from the
            // end; elements without a counterpart are not compared.
            (Self::Slice(lfront, Some(lback)), Self::Slice(rfront, Some(rback))) => lfront
                .iter()
                .zip(rfront.iter())
                .chain(lback.iter().rev().zip(rback.iter().rev()))
                .all(|(x, y)| x.has_overlapping_values(y)),

            // Enums can mix unit variants with tuple/struct variants. These can never overlap.
            (Self::Path(_), Self::Tuple(..) | Self::Struct(..))
            | (Self::Tuple(..) | Self::Struct(..), Self::Path(_)) => false,

            // Tuples can be matched like a struct.
            (Self::Tuple(x, _), Self::Struct(y, _)) | (Self::Struct(x, _), Self::Tuple(y, _)) => {
                // TODO: check fields here.
                x == y
            },

            // TODO: Lit* with Path, Range with Path, LitBytes with Slice
            // Everything not handled above is assumed to overlap.
            _ => true,
        }
    }
}
399 | ||
5099ac24 FG |
400 | fn pat_contains_local(pat: &Pat<'_>, id: HirId) -> bool { |
401 | let mut result = false; | |
402 | pat.walk_short(|p| { | |
403 | result |= matches!(p.kind, PatKind::Binding(_, binding_id, ..) if binding_id == id); | |
404 | !result | |
405 | }); | |
406 | result | |
407 | } | |
408 | ||
409 | /// Returns true if all the bindings in the `Pat` are in `ids` and vice versa | |
410 | fn bindings_eq(pat: &Pat<'_>, mut ids: HirIdSet) -> bool { | |
411 | let mut result = true; | |
412 | pat.each_binding_or_first(&mut |_, id, _, _| result &= ids.remove(&id)); | |
413 | result && ids.is_empty() | |
414 | } |