]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs
Rollup merge of #103805 - Mark-Simulacrum:forward-port, r=jyn514
[rust.git] / compiler / rustc_mir_build / src / thir / pattern / const_to_pat.rs
1 use rustc_errors::DelayDm;
2 use rustc_hir as hir;
3 use rustc_index::vec::Idx;
4 use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
5 use rustc_middle::mir::{self, Field};
6 use rustc_middle::thir::{FieldPat, Pat, PatKind};
7 use rustc_middle::ty::print::with_no_trimmed_paths;
8 use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
9 use rustc_session::lint;
10 use rustc_span::Span;
11 use rustc_trait_selection::traits::predicate_for_trait_def;
12 use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
13 use rustc_trait_selection::traits::{self, ObligationCause, PredicateObligation};
14
15 use std::cell::Cell;
16
17 use super::PatCtxt;
18
19 impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
20     /// Converts an evaluated constant to a pattern (if possible).
21     /// This means aggregate values (like structs and enums) are converted
22     /// to a pattern that matches the value (as if you'd compared via structural equality).
23     #[instrument(level = "debug", skip(self), ret)]
24     pub(super) fn const_to_pat(
25         &self,
26         cv: mir::ConstantKind<'tcx>,
27         id: hir::HirId,
28         span: Span,
29         mir_structural_match_violation: bool,
30     ) -> Box<Pat<'tcx>> {
31         let infcx = self.tcx.infer_ctxt().build();
32         let mut convert = ConstToPat::new(self, id, span, infcx);
33         convert.to_pat(cv, mir_structural_match_violation)
34     }
35 }
36
37 struct ConstToPat<'tcx> {
38     id: hir::HirId,
39     span: Span,
40     param_env: ty::ParamEnv<'tcx>,
41
42     // This tracks if we emitted some hard error for a given const value, so that
43     // we will not subsequently issue an irrelevant lint for the same const
44     // value.
45     saw_const_match_error: Cell<bool>,
46
47     // This tracks if we emitted some diagnostic for a given const value, so that
48     // we will not subsequently issue an irrelevant lint for the same const
49     // value.
50     saw_const_match_lint: Cell<bool>,
51
52     // For backcompat we need to keep allowing non-structurally-eq types behind references.
53     // See also all the `cant-hide-behind` tests.
54     behind_reference: Cell<bool>,
55
56     // inference context used for checking `T: Structural` bounds.
57     infcx: InferCtxt<'tcx>,
58
59     include_lint_checks: bool,
60
61     treat_byte_string_as_slice: bool,
62 }
63
64 mod fallback_to_const_ref {
65     #[derive(Debug)]
66     /// This error type signals that we encountered a non-struct-eq situation behind a reference.
67     /// We bubble this up in order to get back to the reference destructuring and make that emit
68     /// a const pattern instead of a deref pattern. This allows us to simply call `PartialEq::eq`
69     /// on such patterns (since that function takes a reference) and not have to jump through any
70     /// hoops to get a reference to the value.
71     pub(super) struct FallbackToConstRef(());
72
73     pub(super) fn fallback_to_const_ref<'tcx>(c2p: &super::ConstToPat<'tcx>) -> FallbackToConstRef {
74         assert!(c2p.behind_reference.get());
75         FallbackToConstRef(())
76     }
77 }
78 use fallback_to_const_ref::{fallback_to_const_ref, FallbackToConstRef};
79
80 impl<'tcx> ConstToPat<'tcx> {
81     fn new(
82         pat_ctxt: &PatCtxt<'_, 'tcx>,
83         id: hir::HirId,
84         span: Span,
85         infcx: InferCtxt<'tcx>,
86     ) -> Self {
87         trace!(?pat_ctxt.typeck_results.hir_owner);
88         ConstToPat {
89             id,
90             span,
91             infcx,
92             param_env: pat_ctxt.param_env,
93             include_lint_checks: pat_ctxt.include_lint_checks,
94             saw_const_match_error: Cell::new(false),
95             saw_const_match_lint: Cell::new(false),
96             behind_reference: Cell::new(false),
97             treat_byte_string_as_slice: pat_ctxt
98                 .typeck_results
99                 .treat_byte_string_as_slice
100                 .contains(&id.local_id),
101         }
102     }
103
104     fn tcx(&self) -> TyCtxt<'tcx> {
105         self.infcx.tcx
106     }
107
108     fn adt_derive_msg(&self, adt_def: AdtDef<'tcx>) -> String {
109         let path = self.tcx().def_path_str(adt_def.did());
110         format!(
111             "to use a constant of type `{}` in a pattern, \
112             `{}` must be annotated with `#[derive(PartialEq, Eq)]`",
113             path, path,
114         )
115     }
116
117     fn search_for_structural_match_violation(&self, ty: Ty<'tcx>) -> Option<String> {
118         traits::search_for_structural_match_violation(self.span, self.tcx(), ty).map(|non_sm_ty| {
119             with_no_trimmed_paths!(match non_sm_ty.kind() {
120                 ty::Adt(adt, _) => self.adt_derive_msg(*adt),
121                 ty::Dynamic(..) => {
122                     "trait objects cannot be used in patterns".to_string()
123                 }
124                 ty::Opaque(..) => {
125                     "opaque types cannot be used in patterns".to_string()
126                 }
127                 ty::Closure(..) => {
128                     "closures cannot be used in patterns".to_string()
129                 }
130                 ty::Generator(..) | ty::GeneratorWitness(..) => {
131                     "generators cannot be used in patterns".to_string()
132                 }
133                 ty::Float(..) => {
134                     "floating-point numbers cannot be used in patterns".to_string()
135                 }
136                 ty::FnPtr(..) => {
137                     "function pointers cannot be used in patterns".to_string()
138                 }
139                 ty::RawPtr(..) => {
140                     "raw pointers cannot be used in patterns".to_string()
141                 }
142                 _ => {
143                     bug!("use of a value of `{non_sm_ty}` inside a pattern")
144                 }
145             })
146         })
147     }
148
149     fn type_marked_structural(&self, ty: Ty<'tcx>) -> bool {
150         ty.is_structural_eq_shallow(self.infcx.tcx)
151     }
152
153     fn to_pat(
154         &mut self,
155         cv: mir::ConstantKind<'tcx>,
156         mir_structural_match_violation: bool,
157     ) -> Box<Pat<'tcx>> {
158         trace!(self.treat_byte_string_as_slice);
159         // This method is just a wrapper handling a validity check; the heavy lifting is
160         // performed by the recursive `recur` method, which is not meant to be
161         // invoked except by this method.
162         //
163         // once indirect_structural_match is a full fledged error, this
164         // level of indirection can be eliminated
165
166         let inlined_const_as_pat =
167             self.recur(cv, mir_structural_match_violation).unwrap_or_else(|_| {
168                 Box::new(Pat {
169                     span: self.span,
170                     ty: cv.ty(),
171                     kind: PatKind::Constant { value: cv },
172                 })
173             });
174
175         if self.include_lint_checks && !self.saw_const_match_error.get() {
176             // If we were able to successfully convert the const to some pat,
177             // double-check that all types in the const implement `Structural`.
178
179             let structural = self.search_for_structural_match_violation(cv.ty());
180             debug!(
181                 "search_for_structural_match_violation cv.ty: {:?} returned: {:?}",
182                 cv.ty(),
183                 structural
184             );
185
186             // This can occur because const qualification treats all associated constants as
187             // opaque, whereas `search_for_structural_match_violation` tries to monomorphize them
188             // before it runs.
189             //
190             // FIXME(#73448): Find a way to bring const qualification into parity with
191             // `search_for_structural_match_violation`.
192             if structural.is_none() && mir_structural_match_violation {
193                 warn!("MIR const-checker found novel structural match violation. See #73448.");
194                 return inlined_const_as_pat;
195             }
196
197             if let Some(msg) = structural {
198                 if !self.type_may_have_partial_eq_impl(cv.ty()) {
199                     // span_fatal avoids ICE from resolution of non-existent method (rare case).
200                     self.tcx().sess.span_fatal(self.span, &msg);
201                 } else if mir_structural_match_violation && !self.saw_const_match_lint.get() {
202                     self.tcx().struct_span_lint_hir(
203                         lint::builtin::INDIRECT_STRUCTURAL_MATCH,
204                         self.id,
205                         self.span,
206                         msg,
207                         |lint| lint,
208                     );
209                 } else {
210                     debug!(
211                         "`search_for_structural_match_violation` found one, but `CustomEq` was \
212                           not in the qualifs for that `const`"
213                     );
214                 }
215             }
216         }
217
218         inlined_const_as_pat
219     }
220
221     fn type_may_have_partial_eq_impl(&self, ty: Ty<'tcx>) -> bool {
222         // double-check there even *is* a semantic `PartialEq` to dispatch to.
223         //
224         // (If there isn't, then we can safely issue a hard
225         // error, because that's never worked, due to compiler
226         // using `PartialEq::eq` in this scenario in the past.)
227         let partial_eq_trait_id =
228             self.tcx().require_lang_item(hir::LangItem::PartialEq, Some(self.span));
229         let obligation: PredicateObligation<'_> = predicate_for_trait_def(
230             self.tcx(),
231             self.param_env,
232             ObligationCause::misc(self.span, self.id),
233             partial_eq_trait_id,
234             0,
235             ty,
236             &[],
237         );
238         // FIXME: should this call a `predicate_must_hold` variant instead?
239
240         let has_impl = self.infcx.predicate_may_hold(&obligation);
241
242         // Note: To fix rust-lang/rust#65466, we could just remove this type
243         // walk hack for function pointers, and unconditionally error
244         // if `PartialEq` is not implemented. However, that breaks stable
245         // code at the moment, because types like `for <'a> fn(&'a ())` do
246         // not *yet* implement `PartialEq`. So for now we leave this here.
247         has_impl
248             || ty.walk().any(|t| match t.unpack() {
249                 ty::subst::GenericArgKind::Lifetime(_) => false,
250                 ty::subst::GenericArgKind::Type(t) => t.is_fn_ptr(),
251                 ty::subst::GenericArgKind::Const(_) => false,
252             })
253     }
254
255     fn field_pats(
256         &self,
257         vals: impl Iterator<Item = mir::ConstantKind<'tcx>>,
258     ) -> Result<Vec<FieldPat<'tcx>>, FallbackToConstRef> {
259         vals.enumerate()
260             .map(|(idx, val)| {
261                 let field = Field::new(idx);
262                 Ok(FieldPat { field, pattern: self.recur(val, false)? })
263             })
264             .collect()
265     }
266
267     // Recursive helper for `to_pat`; invoke that (instead of calling this directly).
268     #[instrument(skip(self), level = "debug")]
269     fn recur(
270         &self,
271         cv: mir::ConstantKind<'tcx>,
272         mir_structural_match_violation: bool,
273     ) -> Result<Box<Pat<'tcx>>, FallbackToConstRef> {
274         let id = self.id;
275         let span = self.span;
276         let tcx = self.tcx();
277         let param_env = self.param_env;
278
279         let kind = match cv.ty().kind() {
280             ty::Float(_) => {
281                 if self.include_lint_checks {
282                     tcx.struct_span_lint_hir(
283                         lint::builtin::ILLEGAL_FLOATING_POINT_LITERAL_PATTERN,
284                         id,
285                         span,
286                         "floating-point types cannot be used in patterns",
287                         |lint| lint,
288                     );
289                 }
290                 PatKind::Constant { value: cv }
291             }
292             ty::Adt(adt_def, _) if adt_def.is_union() => {
293                 // Matching on union fields is unsafe, we can't hide it in constants
294                 self.saw_const_match_error.set(true);
295                 let msg = "cannot use unions in constant patterns";
296                 if self.include_lint_checks {
297                     tcx.sess.span_err(span, msg);
298                 } else {
299                     tcx.sess.delay_span_bug(span, msg);
300                 }
301                 PatKind::Wild
302             }
303             ty::Adt(..)
304                 if !self.type_may_have_partial_eq_impl(cv.ty())
305                     // FIXME(#73448): Find a way to bring const qualification into parity with
306                     // `search_for_structural_match_violation` and then remove this condition.
307                     && self.search_for_structural_match_violation(cv.ty()).is_some() =>
308             {
309                 // Obtain the actual type that isn't annotated. If we just looked at `cv.ty` we
310                 // could get `Option<NonStructEq>`, even though `Option` is annotated with derive.
311                 let msg = self.search_for_structural_match_violation(cv.ty()).unwrap();
312                 self.saw_const_match_error.set(true);
313                 if self.include_lint_checks {
314                     tcx.sess.span_err(self.span, &msg);
315                 } else {
316                     tcx.sess.delay_span_bug(self.span, &msg);
317                 }
318                 PatKind::Wild
319             }
320             // If the type is not structurally comparable, just emit the constant directly,
321             // causing the pattern match code to treat it opaquely.
322             // FIXME: This code doesn't emit errors itself, the caller emits the errors.
323             // So instead of specific errors, you just get blanket errors about the whole
324             // const type. See
325             // https://github.com/rust-lang/rust/pull/70743#discussion_r404701963 for
326             // details.
327             // Backwards compatibility hack because we can't cause hard errors on these
328             // types, so we compare them via `PartialEq::eq` at runtime.
329             ty::Adt(..) if !self.type_marked_structural(cv.ty()) && self.behind_reference.get() => {
330                 if self.include_lint_checks
331                     && !self.saw_const_match_error.get()
332                     && !self.saw_const_match_lint.get()
333                 {
334                     self.saw_const_match_lint.set(true);
335                     tcx.struct_span_lint_hir(
336                         lint::builtin::INDIRECT_STRUCTURAL_MATCH,
337                         id,
338                         span,
339                         DelayDm(|| {
340                             format!(
341                                 "to use a constant of type `{}` in a pattern, \
342                                  `{}` must be annotated with `#[derive(PartialEq, Eq)]`",
343                                 cv.ty(),
344                                 cv.ty(),
345                             )
346                         }),
347                         |lint| lint,
348                     );
349                 }
350                 // Since we are behind a reference, we can just bubble the error up so we get a
351                 // constant at reference type, making it easy to let the fallback call
352                 // `PartialEq::eq` on it.
353                 return Err(fallback_to_const_ref(self));
354             }
355             ty::Adt(adt_def, _) if !self.type_marked_structural(cv.ty()) => {
356                 debug!(
357                     "adt_def {:?} has !type_marked_structural for cv.ty: {:?}",
358                     adt_def,
359                     cv.ty()
360                 );
361                 let path = tcx.def_path_str(adt_def.did());
362                 let msg = format!(
363                     "to use a constant of type `{}` in a pattern, \
364                      `{}` must be annotated with `#[derive(PartialEq, Eq)]`",
365                     path, path,
366                 );
367                 self.saw_const_match_error.set(true);
368                 if self.include_lint_checks {
369                     tcx.sess.span_err(span, &msg);
370                 } else {
371                     tcx.sess.delay_span_bug(span, &msg);
372                 }
373                 PatKind::Wild
374             }
375             ty::Adt(adt_def, substs) if adt_def.is_enum() => {
376                 let destructured = tcx.destructure_mir_constant(param_env, cv);
377
378                 PatKind::Variant {
379                     adt_def: *adt_def,
380                     substs,
381                     variant_index: destructured
382                         .variant
383                         .expect("destructed const of adt without variant id"),
384                     subpatterns: self.field_pats(destructured.fields.iter().copied())?,
385                 }
386             }
387             ty::Tuple(_) | ty::Adt(_, _) => {
388                 let destructured = tcx.destructure_mir_constant(param_env, cv);
389                 PatKind::Leaf { subpatterns: self.field_pats(destructured.fields.iter().copied())? }
390             }
391             ty::Array(..) => PatKind::Array {
392                 prefix: tcx
393                     .destructure_mir_constant(param_env, cv)
394                     .fields
395                     .iter()
396                     .map(|val| self.recur(*val, false))
397                     .collect::<Result<_, _>>()?,
398                 slice: None,
399                 suffix: Box::new([]),
400             },
401             ty::Ref(_, pointee_ty, ..) => match *pointee_ty.kind() {
402                 // These are not allowed and will error elsewhere anyway.
403                 ty::Dynamic(..) => {
404                     self.saw_const_match_error.set(true);
405                     let msg = format!("`{}` cannot be used in patterns", cv.ty());
406                     if self.include_lint_checks {
407                         tcx.sess.span_err(span, &msg);
408                     } else {
409                         tcx.sess.delay_span_bug(span, &msg);
410                     }
411                     PatKind::Wild
412                 }
413                 // `&str` is represented as `ConstValue::Slice`, let's keep using this
414                 // optimization for now.
415                 ty::Str => PatKind::Constant { value: cv },
416                 // `b"foo"` produces a `&[u8; 3]`, but you can't use constants of array type when
417                 // matching against references, you can only use byte string literals.
418                 // The typechecker has a special case for byte string literals, by treating them
419                 // as slices. This means we turn `&[T; N]` constants into slice patterns, which
420                 // has no negative effects on pattern matching, even if we're actually matching on
421                 // arrays.
422                 ty::Array(..) if !self.treat_byte_string_as_slice => {
423                     let old = self.behind_reference.replace(true);
424                     let array = tcx.deref_mir_constant(self.param_env.and(cv));
425                     let val = PatKind::Deref {
426                         subpattern: Box::new(Pat {
427                             kind: PatKind::Array {
428                                 prefix: tcx
429                                     .destructure_mir_constant(param_env, array)
430                                     .fields
431                                     .iter()
432                                     .map(|val| self.recur(*val, false))
433                                     .collect::<Result<_, _>>()?,
434                                 slice: None,
435                                 suffix: Box::new([]),
436                             },
437                             span,
438                             ty: *pointee_ty,
439                         }),
440                     };
441                     self.behind_reference.set(old);
442                     val
443                 }
444                 ty::Array(elem_ty, _) |
445                 // Cannot merge this with the catch all branch below, because the `const_deref`
446                 // changes the type from slice to array, we need to keep the original type in the
447                 // pattern.
448                 ty::Slice(elem_ty) => {
449                     let old = self.behind_reference.replace(true);
450                     let array = tcx.deref_mir_constant(self.param_env.and(cv));
451                     let val = PatKind::Deref {
452                         subpattern: Box::new(Pat {
453                             kind: PatKind::Slice {
454                                 prefix: tcx
455                                     .destructure_mir_constant(param_env, array)
456                                     .fields
457                                     .iter()
458                                     .map(|val| self.recur(*val, false))
459                                     .collect::<Result<_, _>>()?,
460                                 slice: None,
461                                 suffix: Box::new([]),
462                             },
463                             span,
464                             ty: tcx.mk_slice(elem_ty),
465                         }),
466                     };
467                     self.behind_reference.set(old);
468                     val
469                 }
470                 // Backwards compatibility hack: support references to non-structural types.
471                 // We'll lower
472                 // this pattern to a `PartialEq::eq` comparison and `PartialEq::eq` takes a
473                 // reference. This makes the rest of the matching logic simpler as it doesn't have
474                 // to figure out how to get a reference again.
475                 ty::Adt(adt_def, _) if !self.type_marked_structural(*pointee_ty) => {
476                     if self.behind_reference.get() {
477                         if self.include_lint_checks
478                             && !self.saw_const_match_error.get()
479                             && !self.saw_const_match_lint.get()
480                         {
481                             self.saw_const_match_lint.set(true);
482                             let msg = self.adt_derive_msg(adt_def);
483                             self.tcx().struct_span_lint_hir(
484                                 lint::builtin::INDIRECT_STRUCTURAL_MATCH,
485                                 self.id,
486                                 self.span,
487                                 msg,
488                                 |lint| lint,
489                             );
490                         }
491                         PatKind::Constant { value: cv }
492                     } else {
493                         if !self.saw_const_match_error.get() {
494                             self.saw_const_match_error.set(true);
495                             let msg = self.adt_derive_msg(adt_def);
496                             if self.include_lint_checks {
497                                 tcx.sess.span_err(span, &msg);
498                             } else {
499                                 tcx.sess.delay_span_bug(span, &msg);
500                             }
501                         }
502                         PatKind::Wild
503                     }
504                 }
505                 // All other references are converted into deref patterns and then recursively
506                 // convert the dereferenced constant to a pattern that is the sub-pattern of the
507                 // deref pattern.
508                 _ => {
509                     if !pointee_ty.is_sized(tcx, param_env) {
510                         // `tcx.deref_mir_constant()` below will ICE with an unsized type
511                         // (except slices, which are handled in a separate arm above).
512                         let msg = format!("cannot use unsized non-slice type `{}` in constant patterns", pointee_ty);
513                         if self.include_lint_checks {
514                             tcx.sess.span_err(span, &msg);
515                         } else {
516                             tcx.sess.delay_span_bug(span, &msg);
517                         }
518                         PatKind::Wild
519                     } else {
520                         let old = self.behind_reference.replace(true);
521                         // In case there are structural-match violations somewhere in this subpattern,
522                         // we fall back to a const pattern. If we do not do this, we may end up with
523                         // a !structural-match constant that is not of reference type, which makes it
524                         // very hard to invoke `PartialEq::eq` on it as a fallback.
525                         let val = match self.recur(tcx.deref_mir_constant(self.param_env.and(cv)), false) {
526                             Ok(subpattern) => PatKind::Deref { subpattern },
527                             Err(_) => PatKind::Constant { value: cv },
528                         };
529                         self.behind_reference.set(old);
530                         val
531                     }
532                 }
533             },
534             ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::FnDef(..) => {
535                 PatKind::Constant { value: cv }
536             }
537             ty::RawPtr(pointee) if pointee.ty.is_sized(tcx, param_env) => {
538                 PatKind::Constant { value: cv }
539             }
540             // FIXME: these can have very surprising behaviour where optimization levels or other
541             // compilation choices change the runtime behaviour of the match.
542             // See https://github.com/rust-lang/rust/issues/70861 for examples.
543             ty::FnPtr(..) | ty::RawPtr(..) => {
544                 if self.include_lint_checks
545                     && !self.saw_const_match_error.get()
546                     && !self.saw_const_match_lint.get()
547                 {
548                     self.saw_const_match_lint.set(true);
549                     let msg = "function pointers and unsized pointers in patterns behave \
550                         unpredictably and should not be relied upon. \
551                         See https://github.com/rust-lang/rust/issues/70861 for details.";
552                     tcx.struct_span_lint_hir(
553                         lint::builtin::POINTER_STRUCTURAL_MATCH,
554                         id,
555                         span,
556                         msg,
557                         |lint| lint,
558                     );
559                 }
560                 PatKind::Constant { value: cv }
561             }
562             _ => {
563                 self.saw_const_match_error.set(true);
564                 let msg = format!("`{}` cannot be used in patterns", cv.ty());
565                 if self.include_lint_checks {
566                     tcx.sess.span_err(span, &msg);
567                 } else {
568                     tcx.sess.delay_span_bug(span, &msg);
569                 }
570                 PatKind::Wild
571             }
572         };
573
574         if self.include_lint_checks
575             && !self.saw_const_match_error.get()
576             && !self.saw_const_match_lint.get()
577             && mir_structural_match_violation
578             // FIXME(#73448): Find a way to bring const qualification into parity with
579             // `search_for_structural_match_violation` and then remove this condition.
580             && self.search_for_structural_match_violation(cv.ty()).is_some()
581         {
582             self.saw_const_match_lint.set(true);
583             // Obtain the actual type that isn't annotated. If we just looked at `cv.ty` we
584             // could get `Option<NonStructEq>`, even though `Option` is annotated with derive.
585             let msg = self.search_for_structural_match_violation(cv.ty()).unwrap().replace(
586                 "in a pattern,",
587                 "in a pattern, the constant's initializer must be trivial or",
588             );
589             tcx.struct_span_lint_hir(
590                 lint::builtin::NONTRIVIAL_STRUCTURAL_MATCH,
591                 id,
592                 span,
593                 msg,
594                 |lint| lint,
595             );
596         }
597
598         Ok(Box::new(Pat { span, ty: cv.ty(), kind }))
599     }
600 }