]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs
Rollup merge of #99770 - Nilstrieb:mir-pass-unit-test, r=oli-obk
[rust.git] / compiler / rustc_mir_build / src / thir / pattern / const_to_pat.rs
1 use rustc_hir as hir;
2 use rustc_index::vec::Idx;
3 use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
4 use rustc_middle::mir::{self, Field};
5 use rustc_middle::thir::{FieldPat, Pat, PatKind};
6 use rustc_middle::ty::print::with_no_trimmed_paths;
7 use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
8 use rustc_session::lint;
9 use rustc_span::Span;
10 use rustc_trait_selection::traits::predicate_for_trait_def;
11 use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
12 use rustc_trait_selection::traits::{self, ObligationCause, PredicateObligation};
13
14 use std::cell::Cell;
15
16 use super::PatCtxt;
17
18 impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
19     /// Converts an evaluated constant to a pattern (if possible).
20     /// This means aggregate values (like structs and enums) are converted
21     /// to a pattern that matches the value (as if you'd compared via structural equality).
22     #[instrument(level = "debug", skip(self))]
23     pub(super) fn const_to_pat(
24         &self,
25         cv: mir::ConstantKind<'tcx>,
26         id: hir::HirId,
27         span: Span,
28         mir_structural_match_violation: bool,
29     ) -> Pat<'tcx> {
30         let pat = self.tcx.infer_ctxt().enter(|infcx| {
31             let mut convert = ConstToPat::new(self, id, span, infcx);
32             convert.to_pat(cv, mir_structural_match_violation)
33         });
34
35         debug!(?pat);
36         pat
37     }
38 }
39
40 struct ConstToPat<'a, 'tcx> {
41     id: hir::HirId,
42     span: Span,
43     param_env: ty::ParamEnv<'tcx>,
44
45     // This tracks if we emitted some hard error for a given const value, so that
46     // we will not subsequently issue an irrelevant lint for the same const
47     // value.
48     saw_const_match_error: Cell<bool>,
49
50     // This tracks if we emitted some diagnostic for a given const value, so that
51     // we will not subsequently issue an irrelevant lint for the same const
52     // value.
53     saw_const_match_lint: Cell<bool>,
54
55     // For backcompat we need to keep allowing non-structurally-eq types behind references.
56     // See also all the `cant-hide-behind` tests.
57     behind_reference: Cell<bool>,
58
59     // inference context used for checking `T: Structural` bounds.
60     infcx: InferCtxt<'a, 'tcx>,
61
62     include_lint_checks: bool,
63
64     treat_byte_string_as_slice: bool,
65 }
66
67 mod fallback_to_const_ref {
68     #[derive(Debug)]
69     /// This error type signals that we encountered a non-struct-eq situation behind a reference.
70     /// We bubble this up in order to get back to the reference destructuring and make that emit
71     /// a const pattern instead of a deref pattern. This allows us to simply call `PartialEq::eq`
72     /// on such patterns (since that function takes a reference) and not have to jump through any
73     /// hoops to get a reference to the value.
74     pub(super) struct FallbackToConstRef(());
75
76     pub(super) fn fallback_to_const_ref<'a, 'tcx>(
77         c2p: &super::ConstToPat<'a, 'tcx>,
78     ) -> FallbackToConstRef {
79         assert!(c2p.behind_reference.get());
80         FallbackToConstRef(())
81     }
82 }
83 use fallback_to_const_ref::{fallback_to_const_ref, FallbackToConstRef};
84
85 impl<'a, 'tcx> ConstToPat<'a, 'tcx> {
86     fn new(
87         pat_ctxt: &PatCtxt<'_, 'tcx>,
88         id: hir::HirId,
89         span: Span,
90         infcx: InferCtxt<'a, 'tcx>,
91     ) -> Self {
92         trace!(?pat_ctxt.typeck_results.hir_owner);
93         ConstToPat {
94             id,
95             span,
96             infcx,
97             param_env: pat_ctxt.param_env,
98             include_lint_checks: pat_ctxt.include_lint_checks,
99             saw_const_match_error: Cell::new(false),
100             saw_const_match_lint: Cell::new(false),
101             behind_reference: Cell::new(false),
102             treat_byte_string_as_slice: pat_ctxt
103                 .typeck_results
104                 .treat_byte_string_as_slice
105                 .contains(&id.local_id),
106         }
107     }
108
109     fn tcx(&self) -> TyCtxt<'tcx> {
110         self.infcx.tcx
111     }
112
113     fn adt_derive_msg(&self, adt_def: AdtDef<'tcx>) -> String {
114         let path = self.tcx().def_path_str(adt_def.did());
115         format!(
116             "to use a constant of type `{}` in a pattern, \
117             `{}` must be annotated with `#[derive(PartialEq, Eq)]`",
118             path, path,
119         )
120     }
121
122     fn search_for_structural_match_violation(&self, ty: Ty<'tcx>) -> Option<String> {
123         traits::search_for_structural_match_violation(self.span, self.tcx(), ty).map(|non_sm_ty| {
124             with_no_trimmed_paths!(match non_sm_ty.kind() {
125                 ty::Adt(adt, _) => self.adt_derive_msg(*adt),
126                 ty::Dynamic(..) => {
127                     "trait objects cannot be used in patterns".to_string()
128                 }
129                 ty::Opaque(..) => {
130                     "opaque types cannot be used in patterns".to_string()
131                 }
132                 ty::Closure(..) => {
133                     "closures cannot be used in patterns".to_string()
134                 }
135                 ty::Generator(..) | ty::GeneratorWitness(..) => {
136                     "generators cannot be used in patterns".to_string()
137                 }
138                 ty::Float(..) => {
139                     "floating-point numbers cannot be used in patterns".to_string()
140                 }
141                 ty::FnPtr(..) => {
142                     "function pointers cannot be used in patterns".to_string()
143                 }
144                 ty::RawPtr(..) => {
145                     "raw pointers cannot be used in patterns".to_string()
146                 }
147                 _ => {
148                     bug!("use of a value of `{non_sm_ty}` inside a pattern")
149                 }
150             })
151         })
152     }
153
154     fn type_marked_structural(&self, ty: Ty<'tcx>) -> bool {
155         ty.is_structural_eq_shallow(self.infcx.tcx)
156     }
157
158     fn to_pat(
159         &mut self,
160         cv: mir::ConstantKind<'tcx>,
161         mir_structural_match_violation: bool,
162     ) -> Pat<'tcx> {
163         trace!(self.treat_byte_string_as_slice);
164         // This method is just a wrapper handling a validity check; the heavy lifting is
165         // performed by the recursive `recur` method, which is not meant to be
166         // invoked except by this method.
167         //
168         // once indirect_structural_match is a full fledged error, this
169         // level of indirection can be eliminated
170
171         let inlined_const_as_pat =
172             self.recur(cv, mir_structural_match_violation).unwrap_or_else(|_| Pat {
173                 span: self.span,
174                 ty: cv.ty(),
175                 kind: Box::new(PatKind::Constant { value: cv }),
176             });
177
178         if self.include_lint_checks && !self.saw_const_match_error.get() {
179             // If we were able to successfully convert the const to some pat,
180             // double-check that all types in the const implement `Structural`.
181
182             let structural = self.search_for_structural_match_violation(cv.ty());
183             debug!(
184                 "search_for_structural_match_violation cv.ty: {:?} returned: {:?}",
185                 cv.ty(),
186                 structural
187             );
188
189             // This can occur because const qualification treats all associated constants as
190             // opaque, whereas `search_for_structural_match_violation` tries to monomorphize them
191             // before it runs.
192             //
193             // FIXME(#73448): Find a way to bring const qualification into parity with
194             // `search_for_structural_match_violation`.
195             if structural.is_none() && mir_structural_match_violation {
196                 warn!("MIR const-checker found novel structural match violation. See #73448.");
197                 return inlined_const_as_pat;
198             }
199
200             if let Some(msg) = structural {
201                 if !self.type_may_have_partial_eq_impl(cv.ty()) {
202                     // span_fatal avoids ICE from resolution of non-existent method (rare case).
203                     self.tcx().sess.span_fatal(self.span, &msg);
204                 } else if mir_structural_match_violation && !self.saw_const_match_lint.get() {
205                     self.tcx().struct_span_lint_hir(
206                         lint::builtin::INDIRECT_STRUCTURAL_MATCH,
207                         self.id,
208                         self.span,
209                         |lint| {
210                             lint.build(&msg).emit();
211                         },
212                     );
213                 } else {
214                     debug!(
215                         "`search_for_structural_match_violation` found one, but `CustomEq` was \
216                           not in the qualifs for that `const`"
217                     );
218                 }
219             }
220         }
221
222         inlined_const_as_pat
223     }
224
225     fn type_may_have_partial_eq_impl(&self, ty: Ty<'tcx>) -> bool {
226         // double-check there even *is* a semantic `PartialEq` to dispatch to.
227         //
228         // (If there isn't, then we can safely issue a hard
229         // error, because that's never worked, due to compiler
230         // using `PartialEq::eq` in this scenario in the past.)
231         let partial_eq_trait_id =
232             self.tcx().require_lang_item(hir::LangItem::PartialEq, Some(self.span));
233         let obligation: PredicateObligation<'_> = predicate_for_trait_def(
234             self.tcx(),
235             self.param_env,
236             ObligationCause::misc(self.span, self.id),
237             partial_eq_trait_id,
238             0,
239             ty,
240             &[],
241         );
242         // FIXME: should this call a `predicate_must_hold` variant instead?
243
244         let has_impl = self.infcx.predicate_may_hold(&obligation);
245
246         // Note: To fix rust-lang/rust#65466, we could just remove this type
247         // walk hack for function pointers, and unconditionally error
248         // if `PartialEq` is not implemented. However, that breaks stable
249         // code at the moment, because types like `for <'a> fn(&'a ())` do
250         // not *yet* implement `PartialEq`. So for now we leave this here.
251         has_impl
252             || ty.walk().any(|t| match t.unpack() {
253                 ty::subst::GenericArgKind::Lifetime(_) => false,
254                 ty::subst::GenericArgKind::Type(t) => t.is_fn_ptr(),
255                 ty::subst::GenericArgKind::Const(_) => false,
256             })
257     }
258
259     fn field_pats(
260         &self,
261         vals: impl Iterator<Item = mir::ConstantKind<'tcx>>,
262     ) -> Result<Vec<FieldPat<'tcx>>, FallbackToConstRef> {
263         vals.enumerate()
264             .map(|(idx, val)| {
265                 let field = Field::new(idx);
266                 Ok(FieldPat { field, pattern: self.recur(val, false)? })
267             })
268             .collect()
269     }
270
271     // Recursive helper for `to_pat`; invoke that (instead of calling this directly).
272     #[instrument(skip(self), level = "debug")]
273     fn recur(
274         &self,
275         cv: mir::ConstantKind<'tcx>,
276         mir_structural_match_violation: bool,
277     ) -> Result<Pat<'tcx>, FallbackToConstRef> {
278         let id = self.id;
279         let span = self.span;
280         let tcx = self.tcx();
281         let param_env = self.param_env;
282
283         let kind = match cv.ty().kind() {
284             ty::Float(_) => {
285                 if self.include_lint_checks {
286                     tcx.struct_span_lint_hir(
287                         lint::builtin::ILLEGAL_FLOATING_POINT_LITERAL_PATTERN,
288                         id,
289                         span,
290                         |lint| {
291                             lint.build("floating-point types cannot be used in patterns").emit();
292                         },
293                     );
294                 }
295                 PatKind::Constant { value: cv }
296             }
297             ty::Adt(adt_def, _) if adt_def.is_union() => {
298                 // Matching on union fields is unsafe, we can't hide it in constants
299                 self.saw_const_match_error.set(true);
300                 let msg = "cannot use unions in constant patterns";
301                 if self.include_lint_checks {
302                     tcx.sess.span_err(span, msg);
303                 } else {
304                     tcx.sess.delay_span_bug(span, msg);
305                 }
306                 PatKind::Wild
307             }
308             ty::Adt(..)
309                 if !self.type_may_have_partial_eq_impl(cv.ty())
310                     // FIXME(#73448): Find a way to bring const qualification into parity with
311                     // `search_for_structural_match_violation` and then remove this condition.
312                     && self.search_for_structural_match_violation(cv.ty()).is_some() =>
313             {
314                 // Obtain the actual type that isn't annotated. If we just looked at `cv.ty` we
315                 // could get `Option<NonStructEq>`, even though `Option` is annotated with derive.
316                 let msg = self.search_for_structural_match_violation(cv.ty()).unwrap();
317                 self.saw_const_match_error.set(true);
318                 if self.include_lint_checks {
319                     tcx.sess.span_err(self.span, &msg);
320                 } else {
321                     tcx.sess.delay_span_bug(self.span, &msg);
322                 }
323                 PatKind::Wild
324             }
325             // If the type is not structurally comparable, just emit the constant directly,
326             // causing the pattern match code to treat it opaquely.
327             // FIXME: This code doesn't emit errors itself, the caller emits the errors.
328             // So instead of specific errors, you just get blanket errors about the whole
329             // const type. See
330             // https://github.com/rust-lang/rust/pull/70743#discussion_r404701963 for
331             // details.
332             // Backwards compatibility hack because we can't cause hard errors on these
333             // types, so we compare them via `PartialEq::eq` at runtime.
334             ty::Adt(..) if !self.type_marked_structural(cv.ty()) && self.behind_reference.get() => {
335                 if self.include_lint_checks
336                     && !self.saw_const_match_error.get()
337                     && !self.saw_const_match_lint.get()
338                 {
339                     self.saw_const_match_lint.set(true);
340                     tcx.struct_span_lint_hir(
341                         lint::builtin::INDIRECT_STRUCTURAL_MATCH,
342                         id,
343                         span,
344                         |lint| {
345                             let msg = format!(
346                                 "to use a constant of type `{}` in a pattern, \
347                                  `{}` must be annotated with `#[derive(PartialEq, Eq)]`",
348                                 cv.ty(),
349                                 cv.ty(),
350                             );
351                             lint.build(&msg).emit();
352                         },
353                     );
354                 }
355                 // Since we are behind a reference, we can just bubble the error up so we get a
356                 // constant at reference type, making it easy to let the fallback call
357                 // `PartialEq::eq` on it.
358                 return Err(fallback_to_const_ref(self));
359             }
360             ty::Adt(adt_def, _) if !self.type_marked_structural(cv.ty()) => {
361                 debug!(
362                     "adt_def {:?} has !type_marked_structural for cv.ty: {:?}",
363                     adt_def,
364                     cv.ty()
365                 );
366                 let path = tcx.def_path_str(adt_def.did());
367                 let msg = format!(
368                     "to use a constant of type `{}` in a pattern, \
369                      `{}` must be annotated with `#[derive(PartialEq, Eq)]`",
370                     path, path,
371                 );
372                 self.saw_const_match_error.set(true);
373                 if self.include_lint_checks {
374                     tcx.sess.span_err(span, &msg);
375                 } else {
376                     tcx.sess.delay_span_bug(span, &msg);
377                 }
378                 PatKind::Wild
379             }
380             ty::Adt(adt_def, substs) if adt_def.is_enum() => {
381                 let destructured = tcx.destructure_mir_constant(param_env, cv);
382
383                 PatKind::Variant {
384                     adt_def: *adt_def,
385                     substs,
386                     variant_index: destructured
387                         .variant
388                         .expect("destructed const of adt without variant id"),
389                     subpatterns: self.field_pats(destructured.fields.iter().copied())?,
390                 }
391             }
392             ty::Tuple(_) | ty::Adt(_, _) => {
393                 let destructured = tcx.destructure_mir_constant(param_env, cv);
394                 PatKind::Leaf { subpatterns: self.field_pats(destructured.fields.iter().copied())? }
395             }
396             ty::Array(..) => PatKind::Array {
397                 prefix: tcx
398                     .destructure_mir_constant(param_env, cv)
399                     .fields
400                     .iter()
401                     .map(|val| self.recur(*val, false))
402                     .collect::<Result<_, _>>()?,
403                 slice: None,
404                 suffix: Vec::new(),
405             },
406             ty::Ref(_, pointee_ty, ..) => match *pointee_ty.kind() {
407                 // These are not allowed and will error elsewhere anyway.
408                 ty::Dynamic(..) => {
409                     self.saw_const_match_error.set(true);
410                     let msg = format!("`{}` cannot be used in patterns", cv.ty());
411                     if self.include_lint_checks {
412                         tcx.sess.span_err(span, &msg);
413                     } else {
414                         tcx.sess.delay_span_bug(span, &msg);
415                     }
416                     PatKind::Wild
417                 }
418                 // `&str` is represented as `ConstValue::Slice`, let's keep using this
419                 // optimization for now.
420                 ty::Str => PatKind::Constant { value: cv },
421                 // `b"foo"` produces a `&[u8; 3]`, but you can't use constants of array type when
422                 // matching against references, you can only use byte string literals.
423                 // The typechecker has a special case for byte string literals, by treating them
424                 // as slices. This means we turn `&[T; N]` constants into slice patterns, which
425                 // has no negative effects on pattern matching, even if we're actually matching on
426                 // arrays.
427                 ty::Array(..) if !self.treat_byte_string_as_slice => {
428                     let old = self.behind_reference.replace(true);
429                     let array = tcx.deref_mir_constant(self.param_env.and(cv));
430                     let val = PatKind::Deref {
431                         subpattern: Pat {
432                             kind: Box::new(PatKind::Array {
433                                 prefix: tcx
434                                     .destructure_mir_constant(param_env, array)
435                                     .fields
436                                     .iter()
437                                     .map(|val| self.recur(*val, false))
438                                     .collect::<Result<_, _>>()?,
439                                 slice: None,
440                                 suffix: vec![],
441                             }),
442                             span,
443                             ty: *pointee_ty,
444                         },
445                     };
446                     self.behind_reference.set(old);
447                     val
448                 }
449                 ty::Array(elem_ty, _) |
450                 // Cannot merge this with the catch all branch below, because the `const_deref`
451                 // changes the type from slice to array, we need to keep the original type in the
452                 // pattern.
453                 ty::Slice(elem_ty) => {
454                     let old = self.behind_reference.replace(true);
455                     let array = tcx.deref_mir_constant(self.param_env.and(cv));
456                     let val = PatKind::Deref {
457                         subpattern: Pat {
458                             kind: Box::new(PatKind::Slice {
459                                 prefix: tcx
460                                     .destructure_mir_constant(param_env, array)
461                                     .fields
462                                     .iter()
463                                     .map(|val| self.recur(*val, false))
464                                     .collect::<Result<_, _>>()?,
465                                 slice: None,
466                                 suffix: vec![],
467                             }),
468                             span,
469                             ty: tcx.mk_slice(elem_ty),
470                         },
471                     };
472                     self.behind_reference.set(old);
473                     val
474                 }
475                 // Backwards compatibility hack: support references to non-structural types.
476                 // We'll lower
477                 // this pattern to a `PartialEq::eq` comparison and `PartialEq::eq` takes a
478                 // reference. This makes the rest of the matching logic simpler as it doesn't have
479                 // to figure out how to get a reference again.
480                 ty::Adt(adt_def, _) if !self.type_marked_structural(*pointee_ty) => {
481                     if self.behind_reference.get() {
482                         if self.include_lint_checks
483                             && !self.saw_const_match_error.get()
484                             && !self.saw_const_match_lint.get()
485                         {
486                             self.saw_const_match_lint.set(true);
487                             let msg = self.adt_derive_msg(adt_def);
488                             self.tcx().struct_span_lint_hir(
489                                 lint::builtin::INDIRECT_STRUCTURAL_MATCH,
490                                 self.id,
491                                 self.span,
492                                 |lint| {lint.build(&msg).emit();},
493                             );
494                         }
495                         PatKind::Constant { value: cv }
496                     } else {
497                         if !self.saw_const_match_error.get() {
498                             self.saw_const_match_error.set(true);
499                             let msg = self.adt_derive_msg(adt_def);
500                             if self.include_lint_checks {
501                                 tcx.sess.span_err(span, &msg);
502                             } else {
503                                 tcx.sess.delay_span_bug(span, &msg);
504                             }
505                         }
506                         PatKind::Wild
507                     }
508                 }
509                 // All other references are converted into deref patterns and then recursively
510                 // convert the dereferenced constant to a pattern that is the sub-pattern of the
511                 // deref pattern.
512                 _ => {
513                     if !pointee_ty.is_sized(tcx.at(span), param_env) {
514                         // `tcx.deref_mir_constant()` below will ICE with an unsized type
515                         // (except slices, which are handled in a separate arm above).
516                         let msg = format!("cannot use unsized non-slice type `{}` in constant patterns", pointee_ty);
517                         if self.include_lint_checks {
518                             tcx.sess.span_err(span, &msg);
519                         } else {
520                             tcx.sess.delay_span_bug(span, &msg);
521                         }
522                         PatKind::Wild
523                     } else {
524                         let old = self.behind_reference.replace(true);
525                         // In case there are structural-match violations somewhere in this subpattern,
526                         // we fall back to a const pattern. If we do not do this, we may end up with
527                         // a !structural-match constant that is not of reference type, which makes it
528                         // very hard to invoke `PartialEq::eq` on it as a fallback.
529                         let val = match self.recur(tcx.deref_mir_constant(self.param_env.and(cv)), false) {
530                             Ok(subpattern) => PatKind::Deref { subpattern },
531                             Err(_) => PatKind::Constant { value: cv },
532                         };
533                         self.behind_reference.set(old);
534                         val
535                     }
536                 }
537             },
538             ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::FnDef(..) => {
539                 PatKind::Constant { value: cv }
540             }
541             ty::RawPtr(pointee) if pointee.ty.is_sized(tcx.at(span), param_env) => {
542                 PatKind::Constant { value: cv }
543             }
544             // FIXME: these can have very surprising behaviour where optimization levels or other
545             // compilation choices change the runtime behaviour of the match.
546             // See https://github.com/rust-lang/rust/issues/70861 for examples.
547             ty::FnPtr(..) | ty::RawPtr(..) => {
548                 if self.include_lint_checks
549                     && !self.saw_const_match_error.get()
550                     && !self.saw_const_match_lint.get()
551                 {
552                     self.saw_const_match_lint.set(true);
553                     let msg = "function pointers and unsized pointers in patterns behave \
554                         unpredictably and should not be relied upon. \
555                         See https://github.com/rust-lang/rust/issues/70861 for details.";
556                     tcx.struct_span_lint_hir(
557                         lint::builtin::POINTER_STRUCTURAL_MATCH,
558                         id,
559                         span,
560                         |lint| {
561                             lint.build(msg).emit();
562                         },
563                     );
564                 }
565                 PatKind::Constant { value: cv }
566             }
567             _ => {
568                 self.saw_const_match_error.set(true);
569                 let msg = format!("`{}` cannot be used in patterns", cv.ty());
570                 if self.include_lint_checks {
571                     tcx.sess.span_err(span, &msg);
572                 } else {
573                     tcx.sess.delay_span_bug(span, &msg);
574                 }
575                 PatKind::Wild
576             }
577         };
578
579         if self.include_lint_checks
580             && !self.saw_const_match_error.get()
581             && !self.saw_const_match_lint.get()
582             && mir_structural_match_violation
583             // FIXME(#73448): Find a way to bring const qualification into parity with
584             // `search_for_structural_match_violation` and then remove this condition.
585             && self.search_for_structural_match_violation(cv.ty()).is_some()
586         {
587             self.saw_const_match_lint.set(true);
588             // Obtain the actual type that isn't annotated. If we just looked at `cv.ty` we
589             // could get `Option<NonStructEq>`, even though `Option` is annotated with derive.
590             let msg = self.search_for_structural_match_violation(cv.ty()).unwrap().replace(
591                 "in a pattern,",
592                 "in a pattern, the constant's initializer must be trivial or",
593             );
594             tcx.struct_span_lint_hir(
595                 lint::builtin::NONTRIVIAL_STRUCTURAL_MATCH,
596                 id,
597                 span,
598                 |lint| {
599                     lint.build(&msg).emit();
600                 },
601             );
602         }
603
604         Ok(Pat { span, ty: cv.ty(), kind: Box::new(kind) })
605     }
606 }