]> git.lizzy.rs Git - rust.git/commitdiff
Use now solver in evaluate_obligation
authorMichael Goulet <michael@errs.io>
Fri, 27 Jan 2023 17:46:18 +0000 (17:46 +0000)
committerMichael Goulet <michael@errs.io>
Fri, 27 Jan 2023 17:53:07 +0000 (17:53 +0000)
compiler/rustc_trait_selection/src/traits/query/evaluate_obligation.rs

index 09b58894d3040df0e09649406a5633b1b150235b..f183248f2d08b0474e6ef47b7b3ced870c95d53a 100644 (file)
@@ -1,7 +1,9 @@
 use rustc_middle::ty;
+use rustc_session::config::TraitSolver;
 
 use crate::infer::canonical::OriginalQueryValues;
 use crate::infer::InferCtxt;
+use crate::solve::{Certainty, Goal, InferCtxtEvalExt, MaybeCause};
 use crate::traits::{EvaluationResult, OverflowError, PredicateObligation, SelectionContext};
 
 pub trait InferCtxtExt<'tcx> {
@@ -77,12 +79,38 @@ fn evaluate_obligation(
             _ => obligation.param_env.without_const(),
         };
 
-        let c_pred = self
-            .canonicalize_query_keep_static(param_env.and(obligation.predicate), &mut _orig_values);
-        // Run canonical query. If overflow occurs, rerun from scratch but this time
-        // in standard trait query mode so that overflow is handled appropriately
-        // within `SelectionContext`.
-        self.tcx.at(obligation.cause.span()).evaluate_obligation(c_pred)
+        if self.tcx.sess.opts.unstable_opts.trait_solver != TraitSolver::Next {
+            let c_pred = self.canonicalize_query_keep_static(
+                param_env.and(obligation.predicate),
+                &mut _orig_values,
+            );
+            self.tcx.at(obligation.cause.span()).evaluate_obligation(c_pred)
+        } else {
+            self.probe(|snapshot| {
+                if let Ok((_, certainty)) =
+                    self.evaluate_root_goal(Goal::new(self.tcx, param_env, obligation.predicate))
+                {
+                    match certainty {
+                        Certainty::Yes => {
+                            if self.opaque_types_added_in_snapshot(snapshot) {
+                                Ok(EvaluationResult::EvaluatedToOkModuloOpaqueTypes)
+                            } else if self.region_constraints_added_in_snapshot(snapshot).is_some()
+                            {
+                                Ok(EvaluationResult::EvaluatedToOkModuloRegions)
+                            } else {
+                                Ok(EvaluationResult::EvaluatedToOk)
+                            }
+                        }
+                        Certainty::Maybe(MaybeCause::Ambiguity) => {
+                            Ok(EvaluationResult::EvaluatedToAmbig)
+                        }
+                        Certainty::Maybe(MaybeCause::Overflow) => Err(OverflowError::Canonical),
+                    }
+                } else {
+                    Ok(EvaluationResult::EvaluatedToErr)
+                }
+            })
+        }
     }
 
     // Helper function that canonicalizes and runs the query. If an
@@ -92,6 +120,9 @@ fn evaluate_obligation_no_overflow(
         &self,
         obligation: &PredicateObligation<'tcx>,
     ) -> EvaluationResult {
+        // Run canonical query. If overflow occurs, rerun from scratch but this time
+        // in standard trait query mode so that overflow is handled appropriately
+        // within `SelectionContext`.
         match self.evaluate_obligation(obligation) {
             Ok(result) => result,
             Err(OverflowError::Canonical) => {