]> git.lizzy.rs Git - rust.git/blobdiff - compiler/rustc_trait_selection/src/traits/select/mod.rs
Auto merge of #107443 - cjgillot:generator-less-query, r=compiler-errors
[rust.git] / compiler / rustc_trait_selection / src / traits / select / mod.rs
index f90da95d51668b56cfbb582cbae2231ef6db371f..b8f5aeee2d593f4386107e3d11b4f9a2144cffa4 100644 (file)
@@ -38,6 +38,8 @@
 use rustc_hir as hir;
 use rustc_hir::def_id::DefId;
 use rustc_infer::infer::LateBoundRegionConversionTime;
+use rustc_infer::traits::TraitEngine;
+use rustc_infer::traits::TraitEngineExt;
 use rustc_middle::dep_graph::{DepKind, DepNodeIndex};
 use rustc_middle::mir::interpret::ErrorHandled;
 use rustc_middle::ty::abstract_const::NotConstEvaluatable;
@@ -47,6 +49,7 @@
 use rustc_middle::ty::SubstsRef;
 use rustc_middle::ty::{self, EarlyBinder, PolyProjectionPredicate, ToPolyTraitRef, ToPredicate};
 use rustc_middle::ty::{Ty, TyCtxt, TypeFoldable, TypeVisitable};
+use rustc_session::config::TraitSolver;
 use rustc_span::symbol::sym;
 
 use std::cell::{Cell, RefCell};
@@ -544,10 +547,14 @@ pub fn evaluate_root_obligation(
         obligation: &PredicateObligation<'tcx>,
     ) -> Result<EvaluationResult, OverflowError> {
         self.evaluation_probe(|this| {
-            this.evaluate_predicate_recursively(
-                TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
-                obligation.clone(),
-            )
+            if this.tcx().sess.opts.unstable_opts.trait_solver != TraitSolver::Next {
+                this.evaluate_predicate_recursively(
+                    TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
+                    obligation.clone(),
+                )
+            } else {
+                this.evaluate_predicates_recursively_in_new_solver([obligation.clone()])
+            }
         })
     }
 
@@ -586,18 +593,40 @@ fn evaluate_predicates_recursively<'o, I>(
     where
         I: IntoIterator<Item = PredicateObligation<'tcx>> + std::fmt::Debug,
     {
-        let mut result = EvaluatedToOk;
-        for obligation in predicates {
-            let eval = self.evaluate_predicate_recursively(stack, obligation.clone())?;
-            if let EvaluatedToErr = eval {
-                // fast-path - EvaluatedToErr is the top of the lattice,
-                // so we don't need to look on the other predicates.
-                return Ok(EvaluatedToErr);
-            } else {
-                result = cmp::max(result, eval);
+        if self.tcx().sess.opts.unstable_opts.trait_solver != TraitSolver::Next {
+            let mut result = EvaluatedToOk;
+            for obligation in predicates {
+                let eval = self.evaluate_predicate_recursively(stack, obligation.clone())?;
+                if let EvaluatedToErr = eval {
+                    // fast-path - EvaluatedToErr is the top of the lattice,
+                    // so we don't need to look on the other predicates.
+                    return Ok(EvaluatedToErr);
+                } else {
+                    result = cmp::max(result, eval);
+                }
             }
+            Ok(result)
+        } else {
+            self.evaluate_predicates_recursively_in_new_solver(predicates)
         }
-        Ok(result)
+    }
+
+    /// Evaluates the predicates using the new solver when `-Ztrait-solver=next` is enabled
+    fn evaluate_predicates_recursively_in_new_solver(
+        &mut self,
+        predicates: impl IntoIterator<Item = PredicateObligation<'tcx>>,
+    ) -> Result<EvaluationResult, OverflowError> {
+        let mut fulfill_cx = crate::solve::FulfillmentCtxt::new();
+        fulfill_cx.register_predicate_obligations(self.infcx, predicates);
+        // True errors
+        if !fulfill_cx.select_where_possible(self.infcx).is_empty() {
+            return Ok(EvaluatedToErr);
+        }
+        if !fulfill_cx.select_all_or_error(self.infcx).is_empty() {
+            return Ok(EvaluatedToAmbig);
+        }
+        // Regions and opaques are handled in the `evaluation_probe` by looking at the snapshot
+        Ok(EvaluatedToOk)
     }
 
     #[instrument(
@@ -768,7 +797,7 @@ fn evaluate_predicate_recursively<'o>(
                 }
 
                 ty::PredicateKind::ObjectSafe(trait_def_id) => {
-                    if self.tcx().is_object_safe(trait_def_id) {
+                    if self.tcx().check_is_object_safe(trait_def_id) {
                         Ok(EvaluatedToOk)
                     } else {
                         Ok(EvaluatedToErr)
@@ -2037,6 +2066,7 @@ fn sized_conditions(
             | ty::Ref(..)
             | ty::Generator(..)
             | ty::GeneratorWitness(..)
+            | ty::GeneratorWitnessMIR(..)
             | ty::Array(..)
             | ty::Closure(..)
             | ty::Never
@@ -2153,6 +2183,16 @@ fn copy_clone_conditions(
                 Where(ty::Binder::bind_with_vars(witness_tys.to_vec(), all_vars))
             }
 
+            ty::GeneratorWitnessMIR(def_id, ref substs) => {
+                let hidden_types = bind_generator_hidden_types_above(
+                    self.infcx,
+                    def_id,
+                    substs,
+                    obligation.predicate.bound_vars(),
+                );
+                Where(hidden_types)
+            }
+
             ty::Closure(_, substs) => {
                 // (*) binder moved here
                 let ty = self.infcx.shallow_resolve(substs.as_closure().tupled_upvars_ty());
@@ -2250,6 +2290,10 @@ fn constituent_types_for_ty(
                 types.map_bound(|types| types.to_vec())
             }
 
+            ty::GeneratorWitnessMIR(def_id, ref substs) => {
+                bind_generator_hidden_types_above(self.infcx, def_id, substs, t.bound_vars())
+            }
+
             // For `PhantomData<T>`, we pass `T`.
             ty::Adt(def, substs) if def.is_phantom_data() => t.rebind(substs.types().collect()),
 
@@ -2892,3 +2936,56 @@ pub enum ProjectionMatchesProjection {
     Ambiguous,
     No,
 }
+
+/// Replace all regions inside the generator interior with late bound regions.
+/// Note that each region slot in the types gets a new fresh late bound region, which means that
+/// none of the regions inside relate to any other, even if typeck had previously found constraints
+/// that would cause them to be related.
+#[instrument(level = "trace", skip(infcx), ret)]
+fn bind_generator_hidden_types_above<'tcx>(
+    infcx: &InferCtxt<'tcx>,
+    def_id: DefId,
+    substs: ty::SubstsRef<'tcx>,
+    bound_vars: &ty::List<ty::BoundVariableKind>,
+) -> ty::Binder<'tcx, Vec<Ty<'tcx>>> {
+    let tcx = infcx.tcx;
+    let mut seen_tys = FxHashSet::default();
+
+    let considering_regions = infcx.considering_regions;
+
+    let num_bound_variables = bound_vars.len() as u32;
+    let mut counter = num_bound_variables;
+
+    let hidden_types: Vec<_> = tcx
+        .generator_hidden_types(def_id)
+        // Deduplicate tys to avoid repeated work.
+        .filter(|bty| seen_tys.insert(*bty))
+        .map(|bty| {
+            let mut ty = bty.subst(tcx, substs);
+
+            // Only remap erased regions if we use them.
+            if considering_regions {
+                ty = tcx.fold_regions(ty, |mut r, current_depth| {
+                    if let ty::ReErased = r.kind() {
+                        let br = ty::BoundRegion {
+                            var: ty::BoundVar::from_u32(counter),
+                            kind: ty::BrAnon(counter, None),
+                        };
+                        counter += 1;
+                        r = tcx.mk_region(ty::ReLateBound(current_depth, br));
+                    }
+                    r
+                })
+            }
+
+            ty
+        })
+        .collect();
+    if considering_regions {
+        debug_assert!(!hidden_types.has_erased_regions());
+    }
+    let bound_vars = tcx.mk_bound_variable_kinds(bound_vars.iter().chain(
+        (num_bound_variables..counter).map(|i| ty::BoundVariableKind::Region(ty::BrAnon(i, None))),
+    ));
+    ty::Binder::bind_with_vars(hidden_types, bound_vars)
+}