]> git.lizzy.rs Git - rust.git/commitdiff
Retry canonical trait query in standard mode if overflow occurs
authorAravind Gollakota <aravindprasant@gmail.com>
Thu, 19 Apr 2018 08:15:36 +0000 (03:15 -0500)
committerAravind Gollakota <aravindprasant@gmail.com>
Fri, 27 Apr 2018 01:28:30 +0000 (20:28 -0500)
This is slightly hacky and hopefully only a somewhat temporary solution.

src/librustc/traits/query/evaluate_obligation.rs
src/librustc/traits/select.rs
src/librustc/ty/maps/mod.rs
src/librustc_traits/evaluate_obligation.rs

index 88c51d006db06ee313aeff3042fdc8f2c9996ffc..4e028cac49abe18474503d907447d865a88038f5 100644 (file)
@@ -10,7 +10,8 @@
 
 use infer::InferCtxt;
 use infer::canonical::{Canonical, Canonicalize};
-use traits::{EvaluationResult, PredicateObligation};
+use traits::{EvaluationResult, PredicateObligation, SelectionContext,
+             TraitQueryMode, OverflowError};
 use traits::query::CanonicalPredicateGoal;
 use ty::{ParamEnvAnd, Predicate, TyCtxt};
 
@@ -21,10 +22,7 @@ pub fn predicate_may_hold(
         &self,
         obligation: &PredicateObligation<'tcx>,
     ) -> bool {
-        let (c_pred, _) =
-            self.canonicalize_query(&obligation.param_env.and(obligation.predicate));
-
-        self.tcx.global_tcx().evaluate_obligation(c_pred).may_apply()
+        self.evaluate_obligation(obligation).may_apply()
     }
 
     /// Evaluates whether the predicate can be satisfied in the given
@@ -34,11 +32,29 @@ pub fn predicate_must_hold(
         &self,
         obligation: &PredicateObligation<'tcx>,
     ) -> bool {
+        self.evaluate_obligation(obligation) == EvaluationResult::EvaluatedToOk
+    }
+
+    // Helper function that canonicalizes and runs the query, as well as handles
+    // overflow.
+    fn evaluate_obligation(
+        &self,
+        obligation: &PredicateObligation<'tcx>,
+    ) -> EvaluationResult {
         let (c_pred, _) =
             self.canonicalize_query(&obligation.param_env.and(obligation.predicate));
-
-        self.tcx.global_tcx().evaluate_obligation(c_pred) ==
-            EvaluationResult::EvaluatedToOk
+        // Run canonical query. If overflow occurs, rerun from scratch but this time
+        // in standard trait query mode so that overflow is handled appropriately
+        // within `SelectionContext`.
+        match self.tcx.global_tcx().evaluate_obligation(c_pred) {
+            Ok(result) => result,
+            Err(OverflowError) => {
+                let mut selcx =
+                    SelectionContext::with_query_mode(&self, TraitQueryMode::Standard);
+                selcx.evaluate_obligation_recursively(obligation)
+                     .expect("Overflow should be caught earlier in standard query mode")
+            }
+        }
     }
 }
 
index fdf6dcf4bf37d9096f05b11a687138fe1f33b21a..4ba3655bb644abec51a0c615aedd4d0d58e55ba9 100644 (file)
@@ -425,6 +425,8 @@ fn is_stack_dependent(self) -> bool {
 /// Indicates that trait evaluation caused overflow.
 pub struct OverflowError;
 
+impl_stable_hash_for!(struct OverflowError { });
+
 impl<'tcx> From<OverflowError> for SelectionError<'tcx> {
     fn from(OverflowError: OverflowError) -> SelectionError<'tcx> {
         SelectionError::Overflow
@@ -568,20 +570,23 @@ pub fn select(&mut self, obligation: &TraitObligation<'tcx>)
 
         let stack = self.push_stack(TraitObligationStackList::empty(), obligation);
 
-        // `select` is currently only called in standard query mode
-        assert!(self.query_mode == TraitQueryMode::Standard);
-
         let candidate = match self.candidate_from_obligation(&stack) {
-            Err(SelectionError::Overflow) =>
-                bug!("Overflow should be caught earlier in standard query mode"),
+            Err(SelectionError::Overflow) => {
+                // In standard mode, overflow must have been caught and reported
+                // earlier.
+                assert!(self.query_mode == TraitQueryMode::Canonical);
+                return Err(SelectionError::Overflow);
+            },
             Err(e) => { return Err(e); },
             Ok(None) => { return Ok(None); },
             Ok(Some(candidate)) => candidate
         };
 
         match self.confirm_candidate(obligation, candidate) {
-            Err(SelectionError::Overflow) =>
-                bug!("Overflow should be caught earlier in standard query mode"),
+            Err(SelectionError::Overflow) => {
+                assert!(self.query_mode == TraitQueryMode::Canonical);
+                return Err(SelectionError::Overflow);
+            },
             Err(e) => Err(e),
             Ok(candidate) => Ok(Some(candidate))
         }
index 9343eccd38e9f059887df1903dd8d8ce61105516..cb929225bcdcfd59486b8bf8d4eccc94786885a2 100644 (file)
 
     /// Do not call this query directly: invoke `infcx.predicate_may_hold()` or
     /// `infcx.predicate_must_hold()` instead.
-    [] fn evaluate_obligation:
-        EvaluateObligation(CanonicalPredicateGoal<'tcx>) -> traits::EvaluationResult,
+    [] fn evaluate_obligation: EvaluateObligation(
+        CanonicalPredicateGoal<'tcx>
+    ) -> Result<traits::EvaluationResult, traits::OverflowError>,
 
     [] fn substitute_normalize_and_test_predicates:
         substitute_normalize_and_test_predicates_node((DefId, &'tcx Substs<'tcx>)) -> bool,
index f346bb8dc996b2a1021ce0a69d76ac9d1f4969c3..21259bbcd38ff9df779a2e9d7923079e05126fe7 100644 (file)
@@ -17,7 +17,7 @@
 crate fn evaluate_obligation<'tcx>(
     tcx: TyCtxt<'_, 'tcx, 'tcx>,
     goal: CanonicalPredicateGoal<'tcx>,
-) -> EvaluationResult {
+) -> Result<EvaluationResult, OverflowError> {
     tcx.infer_ctxt().enter(|ref infcx| {
         let (
             ParamEnvAnd {
         let mut selcx = SelectionContext::with_query_mode(&infcx, TraitQueryMode::Canonical);
         let obligation = Obligation::new(ObligationCause::dummy(), param_env, predicate);
 
-        match selcx.evaluate_obligation_recursively(&obligation) {
-            Ok(result) => result,
-            Err(OverflowError) => {
-                infcx.report_overflow_error(&obligation, true)
-            }
-        }
+        selcx.evaluate_obligation_recursively(&obligation)
     })
 }