]> git.lizzy.rs Git - rust.git/commitdiff
Improve time complexity of equality relations
authorMarkus Westerlind <marwes91@gmail.com>
Thu, 3 Mar 2016 09:43:52 +0000 (10:43 +0100)
committerMarkus Westerlind <marwes91@gmail.com>
Mon, 21 Mar 2016 21:40:30 +0000 (22:40 +0100)
This PR adds a `UnificationTable` to the `TypeVariableTable` type which
is used store information about variable equality instead of just
storing them in a vector for later processing. By using a
`UnificationTable` equality relations can be resolved in O(n) (for all
realistic values of n) rather than O(n!) which can give massive
speedups in certain cases (see combine as an example).

Link to combine: https://github.com/Marwes/combine

12 files changed:
src/librustc/middle/infer/bivariate.rs
src/librustc/middle/infer/combine.rs
src/librustc/middle/infer/equate.rs
src/librustc/middle/infer/freshen.rs
src/librustc/middle/infer/higher_ranked/mod.rs
src/librustc/middle/infer/lattice.rs
src/librustc/middle/infer/mod.rs
src/librustc/middle/infer/sub.rs
src/librustc/middle/infer/type_variable.rs
src/librustc/middle/infer/unify_key.rs
src/librustc_data_structures/unify/mod.rs
src/test/run-pass/bench/issue-32062.rs [new file with mode: 0644]

index cb6542856be24aa51c32731b2c7dd71f44e55cb6..485b7d2a9dd50abb24a7b730e15454098527594a 100644 (file)
@@ -77,8 +77,8 @@ fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
         if a == b { return Ok(a); }
 
         let infcx = self.fields.infcx;
-        let a = infcx.type_variables.borrow().replace_if_possible(a);
-        let b = infcx.type_variables.borrow().replace_if_possible(b);
+        let a = infcx.type_variables.borrow_mut().replace_if_possible(a);
+        let b = infcx.type_variables.borrow_mut().replace_if_possible(b);
         match (&a.sty, &b.sty) {
             (&ty::TyInfer(TyVar(a_id)), &ty::TyInfer(TyVar(b_id))) => {
                 infcx.type_variables.borrow_mut().relate_vars(a_id, BiTo, b_id);
index cd4a2eb2d93b48279805b8425141c279e40078c5..1c2af96132559cbfd029fed2cc8e30eca249a5e7 100644 (file)
@@ -210,6 +210,12 @@ pub fn instantiate(&self,
                 None => break,
                 Some(e) => e,
             };
+            // Get the actual variable that b_vid has been inferred to
+            let (b_vid, b_ty) = {
+                let mut variables = self.infcx.type_variables.borrow_mut();
+                let b_vid = variables.root_var(b_vid);
+                (b_vid, variables.probe_root(b_vid))
+            };
 
             debug!("instantiate(a_ty={:?} dir={:?} b_vid={:?})",
                    a_ty,
@@ -219,7 +225,6 @@ pub fn instantiate(&self,
             // Check whether `vid` has been instantiated yet.  If not,
             // make a generalized form of `ty` and instantiate with
             // that.
-            let b_ty = self.infcx.type_variables.borrow().probe(b_vid);
             let b_ty = match b_ty {
                 Some(t) => t, // ...already instantiated.
                 None => {     // ...not yet instantiated:
@@ -307,12 +312,17 @@ fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
         //  where `$1` has already been instantiated with `Box<$0>`)
         match t.sty {
             ty::TyInfer(ty::TyVar(vid)) => {
+                let mut variables = self.infcx.type_variables.borrow_mut();
+                let vid = variables.root_var(vid);
                 if vid == self.for_vid {
                     self.cycle_detected = true;
                     self.tcx().types.err
                 } else {
-                    match self.infcx.type_variables.borrow().probe(vid) {
-                        Some(u) => self.fold_ty(u),
+                    match variables.probe_root(vid) {
+                        Some(u) => {
+                            drop(variables);
+                            self.fold_ty(u)
+                        }
                         None => t,
                     }
                 }
index a10568d1fa33a8c5f59b9a0effc694011044b9c5..92a419fec323c52caef9bd937d6e22290b465ce9 100644 (file)
@@ -50,8 +50,8 @@ fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
         if a == b { return Ok(a); }
 
         let infcx = self.fields.infcx;
-        let a = infcx.type_variables.borrow().replace_if_possible(a);
-        let b = infcx.type_variables.borrow().replace_if_possible(b);
+        let a = infcx.type_variables.borrow_mut().replace_if_possible(a);
+        let b = infcx.type_variables.borrow_mut().replace_if_possible(b);
         match (&a.sty, &b.sty) {
             (&ty::TyInfer(TyVar(a_id)), &ty::TyInfer(TyVar(b_id))) => {
                 infcx.type_variables.borrow_mut().relate_vars(a_id, EqTo, b_id);
index b64fa688d5163480d83014a6996c01b2675d2c83..a81ba03d9ca68467d59478e75eb198dfa9d7e92d 100644 (file)
@@ -111,8 +111,9 @@ fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
 
         match t.sty {
             ty::TyInfer(ty::TyVar(v)) => {
+                let opt_ty = self.infcx.type_variables.borrow_mut().probe(v);
                 self.freshen(
-                    self.infcx.type_variables.borrow().probe(v),
+                    opt_ty,
                     ty::TyVar(v),
                     ty::FreshTy)
             }
index 9b6625886a47c0de7da339d378d8aa64a5751e3f..6cb91438ec36635c5d374d86049087586a9196ca 100644 (file)
@@ -434,7 +434,7 @@ fn region_vars_confined_to_snapshot(&self,
             self.region_vars.vars_created_since_snapshot(&snapshot.region_vars_snapshot);
 
         let escaping_types =
-            self.type_variables.borrow().types_escaping_snapshot(&snapshot.type_snapshot);
+            self.type_variables.borrow_mut().types_escaping_snapshot(&snapshot.type_snapshot);
 
         let mut escaping_region_vars = FnvHashSet();
         for ty in &escaping_types {
index 2a560ec8a1d237dd497d70ef6fe7fc569c10dc75..6b5f2c74a69c642a8c35e9651475094d379ca5fa 100644 (file)
@@ -60,8 +60,8 @@ pub fn super_lattice_tys<'a,'tcx,L:LatticeDir<'a,'tcx>>(this: &mut L,
     }
 
     let infcx = this.infcx();
-    let a = infcx.type_variables.borrow().replace_if_possible(a);
-    let b = infcx.type_variables.borrow().replace_if_possible(b);
+    let a = infcx.type_variables.borrow_mut().replace_if_possible(a);
+    let b = infcx.type_variables.borrow_mut().replace_if_possible(b);
     match (&a.sty, &b.sty) {
         (&ty::TyInfer(TyVar(..)), &ty::TyInfer(TyVar(..)))
             if infcx.type_var_diverges(a) && infcx.type_var_diverges(b) => {
index b9a5b32b71d825e0626c52927c3f1819c230b1a7..a7e67c510727bce03ab6e5d3fe086f81701c5096 100644 (file)
@@ -637,7 +637,7 @@ pub fn unsolved_variables(&self) -> Vec<ty::Ty<'tcx>> {
         let mut variables = Vec::new();
 
         let unbound_ty_vars = self.type_variables
-                                  .borrow()
+                                  .borrow_mut()
                                   .unsolved_variables()
                                   .into_iter()
                                   .map(|t| self.tcx.mk_var(t));
@@ -1162,7 +1162,7 @@ pub fn shallow_resolve(&self, typ: Ty<'tcx>) -> Ty<'tcx> {
                 // structurally), and we prevent cycles in any case,
                 // so this recursion should always be of very limited
                 // depth.
-                self.type_variables.borrow()
+                self.type_variables.borrow_mut()
                     .probe(v)
                     .map(|t| self.shallow_resolve(t))
                     .unwrap_or(typ)
index e13d29b8b4215c7bbe6402d52dd503629c83386d..918a8c362da2d2e20f5fb66b2ff354dd7ac8b1e8 100644 (file)
@@ -65,8 +65,8 @@ fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
         if a == b { return Ok(a); }
 
         let infcx = self.fields.infcx;
-        let a = infcx.type_variables.borrow().replace_if_possible(a);
-        let b = infcx.type_variables.borrow().replace_if_possible(b);
+        let a = infcx.type_variables.borrow_mut().replace_if_possible(a);
+        let b = infcx.type_variables.borrow_mut().replace_if_possible(b);
         match (&a.sty, &b.sty) {
             (&ty::TyInfer(TyVar(a_id)), &ty::TyInfer(TyVar(b_id))) => {
                 infcx.type_variables
index e4af098c2a42d7bce9fdb869243f996570afad6e..fe66ea5a1ea124370e83a86e07fc1dc0f8f1320c 100644 (file)
 use std::mem;
 use std::u32;
 use rustc_data_structures::snapshot_vec as sv;
+use rustc_data_structures::unify as ut;
 
 pub struct TypeVariableTable<'tcx> {
     values: sv::SnapshotVec<Delegate<'tcx>>,
+    eq_relations: ut::UnificationTable<ty::TyVid>,
 }
 
 struct TypeVariableData<'tcx> {
@@ -50,20 +52,22 @@ pub struct Default<'tcx> {
 }
 
 pub struct Snapshot {
-    snapshot: sv::Snapshot
+    snapshot: sv::Snapshot,
+    eq_snapshot: ut::Snapshot<ty::TyVid>,
 }
 
 enum UndoEntry<'tcx> {
     // The type of the var was specified.
     SpecifyVar(ty::TyVid, Vec<Relation>, Option<Default<'tcx>>),
     Relate(ty::TyVid, ty::TyVid),
+    RelateRange(ty::TyVid, usize),
 }
 
 struct Delegate<'tcx>(PhantomData<&'tcx ()>);
 
 type Relation = (RelationDir, ty::TyVid);
 
-#[derive(Copy, Clone, PartialEq, Debug)]
+#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
 pub enum RelationDir {
     SubtypeOf, SupertypeOf, EqTo, BiTo
 }
@@ -81,7 +85,10 @@ fn opposite(self) -> RelationDir {
 
 impl<'tcx> TypeVariableTable<'tcx> {
     pub fn new() -> TypeVariableTable<'tcx> {
-        TypeVariableTable { values: sv::SnapshotVec::new() }
+        TypeVariableTable {
+            values: sv::SnapshotVec::new(),
+            eq_relations: ut::UnificationTable::new(),
+        }
     }
 
     fn relations<'a>(&'a mut self, a: ty::TyVid) -> &'a mut Vec<Relation> {
@@ -103,22 +110,48 @@ pub fn var_diverges<'a>(&'a self, vid: ty::TyVid) -> bool {
     ///
     /// Precondition: neither `a` nor `b` are known.
     pub fn relate_vars(&mut self, a: ty::TyVid, dir: RelationDir, b: ty::TyVid) {
+        let a = self.root_var(a);
+        let b = self.root_var(b);
         if a != b {
-            self.relations(a).push((dir, b));
-            self.relations(b).push((dir.opposite(), a));
-            self.values.record(Relate(a, b));
+            if dir == EqTo {
+                // a and b must be equal which we mark in the unification table
+                let root = self.eq_relations.union(a, b);
+                // In addition to being equal, all relations from the variable which is no longer
+                // the root must be added to the root so they are not forgotten as the other
+                // variable should no longer be referenced (other than to get the root)
+                let other = if a == root { b } else { a };
+                let count = {
+                    let (relations, root_relations) = if other.index < root.index {
+                        let (pre, post) = self.values.split_at_mut(root.index as usize);
+                        (relations(&mut pre[other.index as usize]), relations(&mut post[0]))
+                    } else {
+                        let (pre, post) = self.values.split_at_mut(other.index as usize);
+                        (relations(&mut post[0]), relations(&mut pre[root.index as usize]))
+                    };
+                    root_relations.extend_from_slice(relations);
+                    relations.len()
+                };
+                self.values.record(RelateRange(root, count));
+            } else {
+                self.relations(a).push((dir, b));
+                self.relations(b).push((dir.opposite(), a));
+                self.values.record(Relate(a, b));
+            }
         }
     }
 
     /// Instantiates `vid` with the type `ty` and then pushes an entry onto `stack` for each of the
     /// relations of `vid` to other variables. The relations will have the form `(ty, dir, vid1)`
     /// where `vid1` is some other variable id.
+    ///
+    /// Precondition: `vid` must be a root in the unification table
     pub fn instantiate_and_push(
         &mut self,
         vid: ty::TyVid,
         ty: Ty<'tcx>,
         stack: &mut Vec<(Ty<'tcx>, RelationDir, ty::TyVid)>)
     {
+        debug_assert!(self.root_var(vid) == vid);
         let old_value = {
             let value_ptr = &mut self.values.get_mut(vid.index as usize).value;
             mem::replace(value_ptr, Known(ty))
@@ -140,6 +173,7 @@ pub fn instantiate_and_push(
     pub fn new_var(&mut self,
                    diverging: bool,
                    default: Option<Default<'tcx>>) -> ty::TyVid {
+        self.eq_relations.new_key(());
         let index = self.values.push(TypeVariableData {
             value: Bounded { relations: vec![], default: default },
             diverging: diverging
@@ -147,14 +181,25 @@ pub fn new_var(&mut self,
         ty::TyVid { index: index as u32 }
     }
 
-    pub fn probe(&self, vid: ty::TyVid) -> Option<Ty<'tcx>> {
+    pub fn root_var(&mut self, vid: ty::TyVid) -> ty::TyVid {
+        self.eq_relations.find(vid)
+    }
+
+    pub fn probe(&mut self, vid: ty::TyVid) -> Option<Ty<'tcx>> {
+        let vid = self.root_var(vid);
+        self.probe_root(vid)
+    }
+
+    /// Retrieves the type of `vid` given that it is currently a root in the unification table
+    pub fn probe_root(&mut self, vid: ty::TyVid) -> Option<Ty<'tcx>> {
+        debug_assert!(self.root_var(vid) == vid);
         match self.values.get(vid.index as usize).value {
             Bounded { .. } => None,
             Known(t) => Some(t)
         }
     }
 
-    pub fn replace_if_possible(&self, t: Ty<'tcx>) -> Ty<'tcx> {
+    pub fn replace_if_possible(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
         match t.sty {
             ty::TyInfer(ty::TyVar(v)) => {
                 match self.probe(v) {
@@ -167,18 +212,23 @@ pub fn replace_if_possible(&self, t: Ty<'tcx>) -> Ty<'tcx> {
     }
 
     pub fn snapshot(&mut self) -> Snapshot {
-        Snapshot { snapshot: self.values.start_snapshot() }
+        Snapshot {
+            snapshot: self.values.start_snapshot(),
+            eq_snapshot: self.eq_relations.snapshot(),
+        }
     }
 
     pub fn rollback_to(&mut self, s: Snapshot) {
         self.values.rollback_to(s.snapshot);
+        self.eq_relations.rollback_to(s.eq_snapshot);
     }
 
     pub fn commit(&mut self, s: Snapshot) {
         self.values.commit(s.snapshot);
+        self.eq_relations.commit(s.eq_snapshot);
     }
 
-    pub fn types_escaping_snapshot(&self, s: &Snapshot) -> Vec<Ty<'tcx>> {
+    pub fn types_escaping_snapshot(&mut self, s: &Snapshot) -> Vec<Ty<'tcx>> {
         /*!
          * Find the set of type variables that existed *before* `s`
          * but which have only been unified since `s` started, and
@@ -208,7 +258,10 @@ pub fn types_escaping_snapshot(&self, s: &Snapshot) -> Vec<Ty<'tcx>> {
                     if vid.index < new_elem_threshold {
                         // quick check to see if this variable was
                         // created since the snapshot started or not.
-                        let escaping_type = self.probe(vid).unwrap();
+                        let escaping_type = match self.values.get(vid.index as usize).value {
+                            Bounded { .. } => unreachable!(),
+                            Known(ty) => ty,
+                        };
                         escaping_types.push(escaping_type);
                     }
                     debug!("SpecifyVar({:?}) new_elem_threshold={}", vid, new_elem_threshold);
@@ -221,13 +274,15 @@ pub fn types_escaping_snapshot(&self, s: &Snapshot) -> Vec<Ty<'tcx>> {
         escaping_types
     }
 
-    pub fn unsolved_variables(&self) -> Vec<ty::TyVid> {
-        self.values
-            .iter()
-            .enumerate()
-            .filter_map(|(i, value)| match &value.value {
-                &TypeVariableValue::Known(_) => None,
-                &TypeVariableValue::Bounded { .. } => Some(ty::TyVid { index: i as u32 })
+    pub fn unsolved_variables(&mut self) -> Vec<ty::TyVid> {
+        (0..self.values.len())
+            .filter_map(|i| {
+                let vid = ty::TyVid { index: i as u32 };
+                if self.probe(vid).is_some() {
+                    None
+                } else {
+                    Some(vid)
+                }
             })
             .collect()
     }
@@ -250,6 +305,13 @@ fn reverse(values: &mut Vec<TypeVariableData<'tcx>>, action: UndoEntry<'tcx>) {
                 relations(&mut (*values)[a.index as usize]).pop();
                 relations(&mut (*values)[b.index as usize]).pop();
             }
+
+            RelateRange(i, n) => {
+                let relations = relations(&mut (*values)[i.index as usize]);
+                for _ in 0..n {
+                    relations.pop();
+                }
+            }
         }
     }
 }
index 5008a92a4f59d7facdfdac69b5acb3e70b0283b1..3f8c3fbce047a24fab1f5ce04bae107fc21cba2c 100644 (file)
@@ -73,3 +73,10 @@ fn to_type(&self, tcx: &TyCtxt<'tcx>) -> Ty<'tcx> {
         tcx.mk_mach_float(*self)
     }
 }
+
+impl UnifyKey for ty::TyVid {
+    type Value = ();
+    fn index(&self) -> u32 { self.index }
+    fn from_index(i: u32) -> ty::TyVid { ty::TyVid { index: i } }
+    fn tag(_: Option<ty::TyVid>) -> &'static str { "TyVid" }
+}
index 7a1ac830b22939d79170caee1d58ac1b5a574c94..3feea3218d0138b95100fe70211841692509f0a7 100644 (file)
@@ -211,7 +211,7 @@ fn set(&mut self, key: K, new_value: VarValue<K>) {
     /// really more of a building block. If the values associated with
     /// your key are non-trivial, you would probably prefer to call
     /// `unify_var_var` below.
-    fn unify(&mut self, root_a: VarValue<K>, root_b: VarValue<K>, new_value: K::Value) {
+    fn unify(&mut self, root_a: VarValue<K>, root_b: VarValue<K>, new_value: K::Value) -> K {
         debug!("unify(root_a(id={:?}, rank={:?}), root_b(id={:?}, rank={:?}))",
                root_a.key(),
                root_a.rank,
@@ -221,14 +221,14 @@ fn unify(&mut self, root_a: VarValue<K>, root_b: VarValue<K>, new_value: K::Valu
         if root_a.rank > root_b.rank {
             // a has greater rank, so a should become b's parent,
             // i.e., b should redirect to a.
-            self.redirect_root(root_a.rank, root_b, root_a, new_value);
+            self.redirect_root(root_a.rank, root_b, root_a, new_value)
         } else if root_a.rank < root_b.rank {
             // b has greater rank, so a should redirect to b.
-            self.redirect_root(root_b.rank, root_a, root_b, new_value);
+            self.redirect_root(root_b.rank, root_a, root_b, new_value)
         } else {
             // If equal, redirect one to the other and increment the
             // other's rank.
-            self.redirect_root(root_a.rank + 1, root_a, root_b, new_value);
+            self.redirect_root(root_a.rank + 1, root_a, root_b, new_value)
         }
     }
 
@@ -236,11 +236,12 @@ fn redirect_root(&mut self,
                      new_rank: u32,
                      old_root: VarValue<K>,
                      new_root: VarValue<K>,
-                     new_value: K::Value) {
+                     new_value: K::Value) -> K {
         let old_root_key = old_root.key();
         let new_root_key = new_root.key();
         self.set(old_root_key, old_root.redirect(new_root_key));
         self.set(new_root_key, new_root.root(new_rank, new_value));
+        new_root_key
     }
 }
 
@@ -256,14 +257,16 @@ fn reverse(_: &mut Vec<VarValue<K>>, _: ()) {}
 impl<'tcx, K: UnifyKey> UnificationTable<K>
     where K::Value: Combine
 {
-    pub fn union(&mut self, a_id: K, b_id: K) {
+    pub fn union(&mut self, a_id: K, b_id: K) -> K {
         let node_a = self.get(a_id);
         let node_b = self.get(b_id);
         let a_id = node_a.key();
         let b_id = node_b.key();
         if a_id != b_id {
             let new_value = node_a.value.combine(&node_b.value);
-            self.unify(node_a, node_b, new_value);
+            self.unify(node_a, node_b, new_value)
+        } else {
+            a_id
         }
     }
 
@@ -290,14 +293,14 @@ impl<'tcx, K, V> UnificationTable<K>
     where K: UnifyKey<Value = Option<V>>,
           V: Clone + PartialEq + Debug
 {
-    pub fn unify_var_var(&mut self, a_id: K, b_id: K) -> Result<(), (V, V)> {
+    pub fn unify_var_var(&mut self, a_id: K, b_id: K) -> Result<K, (V, V)> {
         let node_a = self.get(a_id);
         let node_b = self.get(b_id);
         let a_id = node_a.key();
         let b_id = node_b.key();
 
         if a_id == b_id {
-            return Ok(());
+            return Ok(a_id);
         }
 
         let combined = {
diff --git a/src/test/run-pass/bench/issue-32062.rs b/src/test/run-pass/bench/issue-32062.rs
new file mode 100644 (file)
index 0000000..8f6457d
--- /dev/null
@@ -0,0 +1,58 @@
+// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// pretty-expanded FIXME #23616
+
+fn main() {
+    let _ = test(Some(0).into_iter());
+}
+
+trait Parser {
+    type Input: Iterator;
+    type Output;
+    fn parse(self, input: Self::Input) -> Result<(Self::Output, Self::Input), ()>;
+    fn chain<P>(self, p: P) -> Chain<Self, P> where Self: Sized {
+        Chain(self, p)
+    }
+}
+
+struct Token<T>(T::Item) where T: Iterator;
+
+impl<T> Parser for Token<T> where T: Iterator {
+    type Input = T;
+    type Output = T::Item;
+    fn parse(self, _input: Self::Input) -> Result<(Self::Output, Self::Input), ()> {
+        Err(())
+    }
+}
+
+struct Chain<L, R>(L, R);
+
+impl<L, R> Parser for Chain<L, R> where L: Parser, R: Parser<Input = L::Input> {
+    type Input = L::Input;
+    type Output = (L::Output, R::Output);
+    fn parse(self, _input: Self::Input) -> Result<(Self::Output, Self::Input), ()> {
+        Err(())
+    }
+}
+
+fn test<I>(i: I) -> Result<((), I), ()> where I: Iterator<Item = i32> {
+    Chain(Token(0), Token(1))
+        .chain(Chain(Token(0), Token(1)))
+        .chain(Chain(Token(0), Token(1)))
+        .chain(Chain(Token(0), Token(1)))
+        .chain(Chain(Token(0), Token(1)))
+        .chain(Chain(Token(0), Token(1)))
+        .chain(Chain(Token(0), Token(1)))
+        .chain(Chain(Token(0), Token(1)))
+        .chain(Chain(Token(0), Token(1)))
+        .parse(i)
+        .map(|(_, i)| ((), i))
+}