]> git.lizzy.rs Git - rust.git/blobdiff - src/librustc_infer/infer/region_constraints/mod.rs
refactor: Extract the undo log to its own modules
[rust.git] / src / librustc_infer / infer / region_constraints / mod.rs
index 2be6ec4481c6be8c0e39120f78cb8041a9b415ef..ead2494756ce29d23baae6cbd9970c8f47afaf8e 100644 (file)
@@ -4,11 +4,15 @@
 use self::UndoLog::*;
 
 use super::unify_key;
-use super::{MiscVariable, RegionVariableOrigin, SubregionOrigin};
+use super::{
+    InferCtxtUndoLogs, MiscVariable, RegionVariableOrigin, Rollback, Snapshot, SubregionOrigin,
+};
 
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_data_structures::sync::Lrc;
+use rustc_data_structures::undo_log::UndoLogs;
 use rustc_data_structures::unify as ut;
+use rustc_data_structures::unify::UnifyKey;
 use rustc_hir::def_id::DefId;
 use rustc_index::vec::IndexVec;
 use rustc_middle::ty::ReStatic;
@@ -26,7 +30,7 @@
 pub use rustc_middle::infer::MemberConstraint;
 
 #[derive(Default)]
-pub struct RegionConstraintCollector<'tcx> {
+pub struct RegionConstraintStorage<'tcx> {
     /// For each `RegionVid`, the corresponding `RegionVariableOrigin`.
     var_infos: IndexVec<RegionVid, RegionVariableInfo>,
 
@@ -42,20 +46,6 @@ pub struct RegionConstraintCollector<'tcx> {
     /// exist). This prevents us from making many such regions.
     glbs: CombineMap<'tcx>,
 
-    /// The undo log records actions that might later be undone.
-    ///
-    /// Note: `num_open_snapshots` is used to track if we are actively
-    /// snapshotting. When the `start_snapshot()` method is called, we
-    /// increment `num_open_snapshots` to indicate that we are now actively
-    /// snapshotting. The reason for this is that otherwise we end up adding
-    /// entries for things like the lower bound on a variable and so forth,
-    /// which can never be rolled back.
-    undo_log: Vec<UndoLog<'tcx>>,
-
-    /// The number of open snapshots, i.e., those that haven't been committed or
-    /// rolled back.
-    num_open_snapshots: usize,
-
     /// When we add a R1 == R2 constriant, we currently add (a) edges
     /// R1 <= R2 and R2 <= R1 and (b) we unify the two regions in this
     /// table. You can then call `opportunistic_resolve_var` early
@@ -64,13 +54,31 @@ pub struct RegionConstraintCollector<'tcx> {
     /// is iterating to a fixed point, because otherwise we sometimes
     /// would wind up with a fresh stream of region variables that
     /// have been equated but appear distinct.
-    unification_table: ut::UnificationTable<ut::InPlace<ty::RegionVid>>,
+    pub(super) unification_table: ut::UnificationStorage<ty::RegionVid>,
 
     /// a flag set to true when we perform any unifications; this is used
     /// to micro-optimize `take_and_reset_data`
     any_unifications: bool,
 }
 
+pub struct RegionConstraintCollector<'tcx, 'a> {
+    storage: &'a mut RegionConstraintStorage<'tcx>,
+    undo_log: &'a mut InferCtxtUndoLogs<'tcx>,
+}
+
+impl std::ops::Deref for RegionConstraintCollector<'tcx, '_> {
+    type Target = RegionConstraintStorage<'tcx>;
+    fn deref(&self) -> &RegionConstraintStorage<'tcx> {
+        self.storage
+    }
+}
+
+impl std::ops::DerefMut for RegionConstraintCollector<'tcx, '_> {
+    fn deref_mut(&mut self) -> &mut RegionConstraintStorage<'tcx> {
+        self.storage
+    }
+}
+
 pub type VarInfos = IndexVec<RegionVid, RegionVariableInfo>;
 
 /// The full set of region constraints gathered up by the collector.
@@ -258,13 +266,13 @@ pub enum VerifyBound<'tcx> {
 }
 
 #[derive(Copy, Clone, PartialEq, Eq, Hash)]
-struct TwoRegions<'tcx> {
+pub(crate) struct TwoRegions<'tcx> {
     a: Region<'tcx>,
     b: Region<'tcx>,
 }
 
 #[derive(Copy, Clone, PartialEq)]
-enum UndoLog<'tcx> {
+pub(crate) enum UndoLog<'tcx> {
     /// We added `RegionVid`.
     AddVar(RegionVid),
 
@@ -290,7 +298,7 @@ enum UndoLog<'tcx> {
 }
 
 #[derive(Copy, Clone, PartialEq)]
-enum CombineMapType {
+pub(crate) enum CombineMapType {
     Lub,
     Glb,
 }
@@ -304,8 +312,7 @@ pub struct RegionVariableInfo {
 }
 
 pub struct RegionSnapshot {
-    length: usize,
-    region_snapshot: ut::Snapshot<ut::InPlace<ty::RegionVid>>,
+    value_count: usize,
     any_unifications: bool,
 }
 
@@ -334,11 +341,48 @@ pub fn both() -> Self {
     }
 }
 
-impl<'tcx> RegionConstraintCollector<'tcx> {
+impl<'tcx> RegionConstraintStorage<'tcx> {
     pub fn new() -> Self {
         Self::default()
     }
 
+    pub(crate) fn with_log<'a>(
+        &'a mut self,
+        undo_log: &'a mut InferCtxtUndoLogs<'tcx>,
+    ) -> RegionConstraintCollector<'tcx, 'a> {
+        RegionConstraintCollector { storage: self, undo_log }
+    }
+
+    fn rollback_undo_entry(&mut self, undo_entry: UndoLog<'tcx>) {
+        match undo_entry {
+            Purged => {
+                // nothing to do here
+            }
+            AddVar(vid) => {
+                self.var_infos.pop().unwrap();
+                assert_eq!(self.var_infos.len(), vid.index() as usize);
+            }
+            AddConstraint(ref constraint) => {
+                self.data.constraints.remove(constraint);
+            }
+            AddVerify(index) => {
+                self.data.verifys.pop();
+                assert_eq!(self.data.verifys.len(), index);
+            }
+            AddGiven(sub, sup) => {
+                self.data.givens.remove(&(sub, sup));
+            }
+            AddCombination(Glb, ref regions) => {
+                self.glbs.remove(regions);
+            }
+            AddCombination(Lub, ref regions) => {
+                self.lubs.remove(regions);
+            }
+        }
+    }
+}
+
+impl<'tcx> RegionConstraintCollector<'tcx, '_> {
     pub fn num_region_vars(&self) -> usize {
         self.var_infos.len()
     }
@@ -351,8 +395,8 @@ pub fn region_constraint_data(&self) -> &RegionConstraintData<'tcx> {
     ///
     /// Not legal during a snapshot.
     pub fn into_infos_and_data(self) -> (VarInfos, RegionConstraintData<'tcx>) {
-        assert!(!self.in_snapshot());
-        (self.var_infos, self.data)
+        assert!(!UndoLogs::<super::UndoLog<'_>>::in_snapshot(&self.undo_log));
+        (mem::take(&mut self.storage.var_infos), mem::take(&mut self.storage.data))
     }
 
     /// Takes (and clears) the current set of constraints. Note that
@@ -368,21 +412,19 @@ pub fn into_infos_and_data(self) -> (VarInfos, RegionConstraintData<'tcx>) {
     ///
     /// Not legal during a snapshot.
     pub fn take_and_reset_data(&mut self) -> RegionConstraintData<'tcx> {
-        assert!(!self.in_snapshot());
+        assert!(!UndoLogs::<super::UndoLog<'_>>::in_snapshot(&self.undo_log));
 
         // If you add a new field to `RegionConstraintCollector`, you
         // should think carefully about whether it needs to be cleared
         // or updated in some way.
-        let RegionConstraintCollector {
+        let RegionConstraintStorage {
             var_infos: _,
             data,
             lubs,
             glbs,
-            undo_log: _,
-            num_open_snapshots: _,
-            unification_table,
+            unification_table: _,
             any_unifications,
-        } = self;
+        } = self.storage;
 
         // Clear the tables of (lubs, glbs), so that we will create
         // fresh regions if we do a LUB operation. As it happens,
@@ -391,102 +433,38 @@ pub fn take_and_reset_data(&mut self) -> RegionConstraintData<'tcx> {
         lubs.clear();
         glbs.clear();
 
+        let data = mem::take(data);
+
         // Clear all unifications and recreate the variables a "now
         // un-unified" state. Note that when we unify `a` and `b`, we
         // also insert `a <= b` and a `b <= a` edges, so the
         // `RegionConstraintData` contains the relationship here.
         if *any_unifications {
-            unification_table.reset_unifications(|vid| unify_key::RegionVidKey { min_vid: vid });
             *any_unifications = false;
+            self.unification_table()
+                .reset_unifications(|vid| unify_key::RegionVidKey { min_vid: vid });
         }
 
-        mem::take(data)
+        data
     }
 
     pub fn data(&self) -> &RegionConstraintData<'tcx> {
         &self.data
     }
 
-    fn in_snapshot(&self) -> bool {
-        self.num_open_snapshots > 0
-    }
-
     pub fn start_snapshot(&mut self) -> RegionSnapshot {
-        let length = self.undo_log.len();
-        debug!("RegionConstraintCollector: start_snapshot({})", length);
-        self.num_open_snapshots += 1;
+        debug!("RegionConstraintCollector: start_snapshot");
         RegionSnapshot {
-            length,
-            region_snapshot: self.unification_table.snapshot(),
+            value_count: self.unification_table.len(),
             any_unifications: self.any_unifications,
         }
     }
 
-    fn assert_open_snapshot(&self, snapshot: &RegionSnapshot) {
-        assert!(self.undo_log.len() >= snapshot.length);
-        assert!(self.num_open_snapshots > 0);
-    }
-
-    pub fn commit(&mut self, snapshot: RegionSnapshot) {
-        debug!("RegionConstraintCollector: commit({})", snapshot.length);
-        self.assert_open_snapshot(&snapshot);
-
-        if self.num_open_snapshots == 1 {
-            // The root snapshot. It's safe to clear the undo log because
-            // there's no snapshot further out that we might need to roll back
-            // to.
-            assert!(snapshot.length == 0);
-            self.undo_log.clear();
-        }
-
-        self.num_open_snapshots -= 1;
-
-        self.unification_table.commit(snapshot.region_snapshot);
-    }
-
     pub fn rollback_to(&mut self, snapshot: RegionSnapshot) {
         debug!("RegionConstraintCollector: rollback_to({:?})", snapshot);
-        self.assert_open_snapshot(&snapshot);
-
-        while self.undo_log.len() > snapshot.length {
-            let undo_entry = self.undo_log.pop().unwrap();
-            self.rollback_undo_entry(undo_entry);
-        }
-
-        self.num_open_snapshots -= 1;
-
-        self.unification_table.rollback_to(snapshot.region_snapshot);
         self.any_unifications = snapshot.any_unifications;
     }
 
-    fn rollback_undo_entry(&mut self, undo_entry: UndoLog<'tcx>) {
-        match undo_entry {
-            Purged => {
-                // nothing to do here
-            }
-            AddVar(vid) => {
-                self.var_infos.pop().unwrap();
-                assert_eq!(self.var_infos.len(), vid.index() as usize);
-            }
-            AddConstraint(ref constraint) => {
-                self.data.constraints.remove(constraint);
-            }
-            AddVerify(index) => {
-                self.data.verifys.pop();
-                assert_eq!(self.data.verifys.len(), index);
-            }
-            AddGiven(sub, sup) => {
-                self.data.givens.remove(&(sub, sup));
-            }
-            AddCombination(Glb, ref regions) => {
-                self.glbs.remove(regions);
-            }
-            AddCombination(Lub, ref regions) => {
-                self.lubs.remove(regions);
-            }
-        }
-    }
-
     pub fn new_region_var(
         &mut self,
         universe: ty::UniverseIndex,
@@ -494,11 +472,9 @@ pub fn new_region_var(
     ) -> RegionVid {
         let vid = self.var_infos.push(RegionVariableInfo { origin, universe });
 
-        let u_vid = self.unification_table.new_key(unify_key::RegionVidKey { min_vid: vid });
+        let u_vid = self.unification_table().new_key(unify_key::RegionVidKey { min_vid: vid });
         assert_eq!(vid, u_vid);
-        if self.in_snapshot() {
-            self.undo_log.push(AddVar(vid));
-        }
+        self.undo_log.push(AddVar(vid));
         debug!("created new region variable {:?} in {:?} with origin {:?}", vid, universe, origin);
         vid
     }
@@ -520,19 +496,29 @@ pub fn var_origin(&self, vid: RegionVid) -> RegionVariableOrigin {
     pub fn pop_placeholders(&mut self, placeholders: &FxHashSet<ty::Region<'tcx>>) {
         debug!("pop_placeholders(placeholders={:?})", placeholders);
 
-        assert!(self.in_snapshot());
+        assert!(UndoLogs::<super::UndoLog<'_>>::in_snapshot(&self.undo_log));
 
         let constraints_to_kill: Vec<usize> = self
             .undo_log
             .iter()
             .enumerate()
             .rev()
-            .filter(|&(_, undo_entry)| kill_constraint(placeholders, undo_entry))
+            .filter(|&(_, undo_entry)| match undo_entry {
+                super::UndoLog::RegionConstraintCollector(undo_entry) => {
+                    kill_constraint(placeholders, undo_entry)
+                }
+                _ => false,
+            })
             .map(|(index, _)| index)
             .collect();
 
         for index in constraints_to_kill {
-            let undo_entry = mem::replace(&mut self.undo_log[index], Purged);
+            let undo_entry = match &mut self.undo_log[index] {
+                super::UndoLog::RegionConstraintCollector(undo_entry) => {
+                    mem::replace(undo_entry, Purged)
+                }
+                _ => unreachable!(),
+            };
             self.rollback_undo_entry(undo_entry);
         }
 
@@ -566,12 +552,9 @@ fn add_constraint(&mut self, constraint: Constraint<'tcx>, origin: SubregionOrig
         // never overwrite an existing (constraint, origin) - only insert one if it isn't
         // present in the map yet. This prevents origins from outside the snapshot being
         // replaced with "less informative" origins e.g., during calls to `can_eq`
-        let in_snapshot = self.in_snapshot();
         let undo_log = &mut self.undo_log;
-        self.data.constraints.entry(constraint).or_insert_with(|| {
-            if in_snapshot {
-                undo_log.push(AddConstraint(constraint));
-            }
+        self.storage.data.constraints.entry(constraint).or_insert_with(|| {
+            undo_log.push(AddConstraint(constraint));
             origin
         });
     }
@@ -589,9 +572,7 @@ fn add_verify(&mut self, verify: Verify<'tcx>) {
 
         let index = self.data.verifys.len();
         self.data.verifys.push(verify);
-        if self.in_snapshot() {
-            self.undo_log.push(AddVerify(index));
-        }
+        self.undo_log.push(AddVerify(index));
     }
 
     pub fn add_given(&mut self, sub: Region<'tcx>, sup: ty::RegionVid) {
@@ -599,9 +580,7 @@ pub fn add_given(&mut self, sub: Region<'tcx>, sup: ty::RegionVid) {
         if self.data.givens.insert((sub, sup)) {
             debug!("add_given({:?} <= {:?})", sub, sup);
 
-            if self.in_snapshot() {
-                self.undo_log.push(AddGiven(sub, sup));
-            }
+            self.undo_log.push(AddGiven(sub, sup));
         }
     }
 
@@ -619,7 +598,7 @@ pub fn make_eqregion(
 
             if let (ty::ReVar(sub), ty::ReVar(sup)) = (*sub, *sup) {
                 debug!("make_eqregion: uniying {:?} with {:?}", sub, sup);
-                self.unification_table.union(sub, sup);
+                self.unification_table().union(sub, sup);
                 self.any_unifications = true;
             }
         }
@@ -741,7 +720,7 @@ pub fn opportunistic_resolve_var(
         tcx: TyCtxt<'tcx>,
         rid: RegionVid,
     ) -> ty::Region<'tcx> {
-        let vid = self.unification_table.probe_value(rid).min_vid;
+        let vid = self.unification_table().probe_value(rid).min_vid;
         tcx.mk_region(ty::ReVar(vid))
     }
 
@@ -769,9 +748,7 @@ fn combine_vars(
         let c_universe = cmp::max(a_universe, b_universe);
         let c = self.new_region_var(c_universe, MiscVariable(origin.span()));
         self.combine_map(t).insert(vars, c);
-        if self.in_snapshot() {
-            self.undo_log.push(AddCombination(t, vars));
-        }
+        self.undo_log.push(AddCombination(t, vars));
         let new_r = tcx.mk_region(ReVar(c));
         for &old_r in &[a, b] {
             match t {
@@ -801,7 +778,8 @@ pub fn vars_since_snapshot(
         &self,
         mark: &RegionSnapshot,
     ) -> (Range<RegionVid>, Vec<RegionVariableOrigin>) {
-        let range = self.unification_table.vars_since_snapshot(&mark.region_snapshot);
+        let range = RegionVid::from_index(mark.value_count as u32)
+            ..RegionVid::from_index(self.unification_table.len() as u32);
         (
             range.clone(),
             (range.start.index()..range.end.index())
@@ -810,10 +788,10 @@ pub fn vars_since_snapshot(
         )
     }
 
-    /// See `InferCtxt::region_constraints_added_in_snapshot`.
-    pub fn region_constraints_added_in_snapshot(&self, mark: &RegionSnapshot) -> Option<bool> {
-        self.undo_log[mark.length..]
-            .iter()
+    /// See [`RegionInference::region_constraints_added_in_snapshot`].
+    pub fn region_constraints_added_in_snapshot(&self, mark: &Snapshot<'tcx>) -> Option<bool> {
+        self.undo_log
+            .region_constraints_in_snapshot(mark)
             .map(|&elt| match elt {
                 AddConstraint(constraint) => Some(constraint.involves_placeholders()),
                 _ => None,
@@ -821,11 +799,15 @@ pub fn region_constraints_added_in_snapshot(&self, mark: &RegionSnapshot) -> Opt
             .max()
             .unwrap_or(None)
     }
+
+    fn unification_table(&mut self) -> super::UnificationTable<'_, 'tcx, ty::RegionVid> {
+        ut::UnificationTable::with_log(&mut self.storage.unification_table, self.undo_log)
+    }
 }
 
 impl fmt::Debug for RegionSnapshot {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "RegionSnapshot(length={})", self.length)
+        write!(f, "RegionSnapshot")
     }
 }
 
@@ -910,3 +892,9 @@ pub fn is_empty(&self) -> bool {
             && givens.is_empty()
     }
 }
+
+impl<'tcx> Rollback<UndoLog<'tcx>> for RegionConstraintStorage<'tcx> {
+    fn reverse(&mut self, undo: UndoLog<'tcx>) {
+        self.rollback_undo_entry(undo)
+    }
+}