]> git.lizzy.rs Git - rust.git/commitdiff
Extract unification code to unify module
authorFlorian Diebold <flodiebold@gmail.com>
Sun, 1 Dec 2019 19:30:28 +0000 (20:30 +0100)
committerFlorian Diebold <flodiebold@gmail.com>
Mon, 2 Dec 2019 18:33:13 +0000 (19:33 +0100)
crates/ra_hir_ty/src/infer.rs
crates/ra_hir_ty/src/infer/coerce.rs
crates/ra_hir_ty/src/infer/expr.rs
crates/ra_hir_ty/src/infer/pat.rs
crates/ra_hir_ty/src/infer/path.rs
crates/ra_hir_ty/src/infer/unify.rs

index fe259371f5f9fbc54c48ba6acb74ba34b4cb03aa..81afbd2b47f7c5031df7461ca7ca133e5e5099cb 100644 (file)
@@ -18,7 +18,6 @@
 use std::ops::Index;
 use std::sync::Arc;
 
-use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue};
 use rustc_hash::FxHashMap;
 
 use hir_def::{
 use hir_expand::{diagnostics::DiagnosticSink, name};
 use ra_arena::map::ArenaMap;
 use ra_prof::profile;
-use test_utils::tested_by;
 
 use super::{
     primitive::{FloatTy, IntTy},
     traits::{Guidance, Obligation, ProjectionPredicate, Solution},
-    ApplicationTy, InEnvironment, ProjectionTy, Substs, TraitEnvironment, TraitRef, Ty, TypeCtor,
+    ApplicationTy, InEnvironment, ProjectionTy, TraitEnvironment, TraitRef, Ty, TypeCtor,
     TypeWalk, Uncertain,
 };
 use crate::{db::HirDatabase, infer::diagnostics::InferenceDiagnostic};
@@ -191,7 +189,7 @@ struct InferenceContext<'a, D: HirDatabase> {
     owner: DefWithBodyId,
     body: Arc<Body>,
     resolver: Resolver,
-    var_unification_table: InPlaceUnificationTable<TypeVarId>,
+    table: unify::InferenceTable,
     trait_env: Arc<TraitEnvironment>,
     obligations: Vec<Obligation>,
     result: InferenceResult,
@@ -209,7 +207,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
     fn new(db: &'a D, owner: DefWithBodyId, resolver: Resolver) -> Self {
         InferenceContext {
             result: InferenceResult::default(),
-            var_unification_table: InPlaceUnificationTable::new(),
+            table: unify::InferenceTable::new(),
             obligations: Vec::default(),
             return_ty: Ty::Unknown, // set in collect_fn_signature
             trait_env: TraitEnvironment::lower(db, &resolver),
@@ -224,13 +222,12 @@ fn new(db: &'a D, owner: DefWithBodyId, resolver: Resolver) -> Self {
     fn resolve_all(mut self) -> InferenceResult {
         // FIXME resolve obligations as well (use Guidance if necessary)
         let mut result = mem::replace(&mut self.result, InferenceResult::default());
-        let mut tv_stack = Vec::new();
         for ty in result.type_of_expr.values_mut() {
-            let resolved = self.resolve_ty_completely(&mut tv_stack, mem::replace(ty, Ty::Unknown));
+            let resolved = self.table.resolve_ty_completely(mem::replace(ty, Ty::Unknown));
             *ty = resolved;
         }
         for ty in result.type_of_pat.values_mut() {
-            let resolved = self.resolve_ty_completely(&mut tv_stack, mem::replace(ty, Ty::Unknown));
+            let resolved = self.table.resolve_ty_completely(mem::replace(ty, Ty::Unknown));
             *ty = resolved;
         }
         result
@@ -275,96 +272,15 @@ fn make_ty(&mut self, type_ref: &TypeRef) -> Ty {
         self.normalize_associated_types_in(ty)
     }
 
-    fn unify_substs(&mut self, substs1: &Substs, substs2: &Substs, depth: usize) -> bool {
-        substs1.0.iter().zip(substs2.0.iter()).all(|(t1, t2)| self.unify_inner(t1, t2, depth))
-    }
-
-    fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
-        self.unify_inner(ty1, ty2, 0)
-    }
-
-    fn unify_inner(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool {
-        if depth > 1000 {
-            // prevent stackoverflows
-            panic!("infinite recursion in unification");
-        }
-        if ty1 == ty2 {
-            return true;
-        }
-        // try to resolve type vars first
-        let ty1 = self.resolve_ty_shallow(ty1);
-        let ty2 = self.resolve_ty_shallow(ty2);
-        match (&*ty1, &*ty2) {
-            (Ty::Apply(a_ty1), Ty::Apply(a_ty2)) if a_ty1.ctor == a_ty2.ctor => {
-                self.unify_substs(&a_ty1.parameters, &a_ty2.parameters, depth + 1)
-            }
-            _ => self.unify_inner_trivial(&ty1, &ty2),
-        }
-    }
-
-    fn unify_inner_trivial(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
-        match (ty1, ty2) {
-            (Ty::Unknown, _) | (_, Ty::Unknown) => true,
-
-            (Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2)))
-            | (Ty::Infer(InferTy::IntVar(tv1)), Ty::Infer(InferTy::IntVar(tv2)))
-            | (Ty::Infer(InferTy::FloatVar(tv1)), Ty::Infer(InferTy::FloatVar(tv2)))
-            | (
-                Ty::Infer(InferTy::MaybeNeverTypeVar(tv1)),
-                Ty::Infer(InferTy::MaybeNeverTypeVar(tv2)),
-            ) => {
-                // both type vars are unknown since we tried to resolve them
-                self.var_unification_table.union(*tv1, *tv2);
-                true
-            }
-
-            // The order of MaybeNeverTypeVar matters here.
-            // Unifying MaybeNeverTypeVar and TypeVar will let the latter become MaybeNeverTypeVar.
-            // Unifying MaybeNeverTypeVar and other concrete type will let the former become it.
-            (Ty::Infer(InferTy::TypeVar(tv)), other)
-            | (other, Ty::Infer(InferTy::TypeVar(tv)))
-            | (Ty::Infer(InferTy::MaybeNeverTypeVar(tv)), other)
-            | (other, Ty::Infer(InferTy::MaybeNeverTypeVar(tv)))
-            | (Ty::Infer(InferTy::IntVar(tv)), other @ ty_app!(TypeCtor::Int(_)))
-            | (other @ ty_app!(TypeCtor::Int(_)), Ty::Infer(InferTy::IntVar(tv)))
-            | (Ty::Infer(InferTy::FloatVar(tv)), other @ ty_app!(TypeCtor::Float(_)))
-            | (other @ ty_app!(TypeCtor::Float(_)), Ty::Infer(InferTy::FloatVar(tv))) => {
-                // the type var is unknown since we tried to resolve it
-                self.var_unification_table.union_value(*tv, TypeVarValue::Known(other.clone()));
-                true
-            }
-
-            _ => false,
-        }
-    }
-
-    fn new_type_var(&mut self) -> Ty {
-        Ty::Infer(InferTy::TypeVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
-    }
-
-    fn new_integer_var(&mut self) -> Ty {
-        Ty::Infer(InferTy::IntVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
-    }
-
-    fn new_float_var(&mut self) -> Ty {
-        Ty::Infer(InferTy::FloatVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
-    }
-
-    fn new_maybe_never_type_var(&mut self) -> Ty {
-        Ty::Infer(InferTy::MaybeNeverTypeVar(
-            self.var_unification_table.new_key(TypeVarValue::Unknown),
-        ))
-    }
-
     /// Replaces Ty::Unknown by a new type var, so we can maybe still infer it.
     fn insert_type_vars_shallow(&mut self, ty: Ty) -> Ty {
         match ty {
-            Ty::Unknown => self.new_type_var(),
+            Ty::Unknown => self.table.new_type_var(),
             Ty::Apply(ApplicationTy { ctor: TypeCtor::Int(Uncertain::Unknown), .. }) => {
-                self.new_integer_var()
+                self.table.new_integer_var()
             }
             Ty::Apply(ApplicationTy { ctor: TypeCtor::Float(Uncertain::Unknown), .. }) => {
-                self.new_float_var()
+                self.table.new_float_var()
             }
             _ => ty,
         }
@@ -402,64 +318,22 @@ fn resolve_obligations_as_possible(&mut self) {
         }
     }
 
+    fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
+        self.table.unify(ty1, ty2)
+    }
+
     /// Resolves the type as far as currently possible, replacing type variables
     /// by their known types. All types returned by the infer_* functions should
     /// be resolved as far as possible, i.e. contain no type variables with
     /// known type.
-    fn resolve_ty_as_possible(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
+    fn resolve_ty_as_possible(&mut self, ty: Ty) -> Ty {
         self.resolve_obligations_as_possible();
 
-        ty.fold(&mut |ty| match ty {
-            Ty::Infer(tv) => {
-                let inner = tv.to_inner();
-                if tv_stack.contains(&inner) {
-                    tested_by!(type_var_cycles_resolve_as_possible);
-                    // recursive type
-                    return tv.fallback_value();
-                }
-                if let Some(known_ty) =
-                    self.var_unification_table.inlined_probe_value(inner).known()
-                {
-                    // known_ty may contain other variables that are known by now
-                    tv_stack.push(inner);
-                    let result = self.resolve_ty_as_possible(tv_stack, known_ty.clone());
-                    tv_stack.pop();
-                    result
-                } else {
-                    ty
-                }
-            }
-            _ => ty,
-        })
+        self.table.resolve_ty_as_possible(ty)
     }
 
-    /// If `ty` is a type variable with known type, returns that type;
-    /// otherwise, return ty.
     fn resolve_ty_shallow<'b>(&mut self, ty: &'b Ty) -> Cow<'b, Ty> {
-        let mut ty = Cow::Borrowed(ty);
-        // The type variable could resolve to a int/float variable. Hence try
-        // resolving up to three times; each type of variable shouldn't occur
-        // more than once
-        for i in 0..3 {
-            if i > 0 {
-                tested_by!(type_var_resolves_to_int_var);
-            }
-            match &*ty {
-                Ty::Infer(tv) => {
-                    let inner = tv.to_inner();
-                    match self.var_unification_table.inlined_probe_value(inner).known() {
-                        Some(known_ty) => {
-                            // The known_ty can't be a type var itself
-                            ty = Cow::Owned(known_ty.clone());
-                        }
-                        _ => return ty,
-                    }
-                }
-                _ => return ty,
-            }
-        }
-        log::error!("Inference variable still not resolved: {:?}", ty);
-        ty
+        self.table.resolve_ty_shallow(ty)
     }
 
     /// Recurses through the given type, normalizing associated types mentioned
@@ -469,7 +343,7 @@ fn resolve_ty_shallow<'b>(&mut self, ty: &'b Ty) -> Cow<'b, Ty> {
     /// call). `make_ty` handles this already, but e.g. for field types we need
     /// to do it as well.
     fn normalize_associated_types_in(&mut self, ty: Ty) -> Ty {
-        let ty = self.resolve_ty_as_possible(&mut vec![], ty);
+        let ty = self.resolve_ty_as_possible(ty);
         ty.fold(&mut |ty| match ty {
             Ty::Projection(proj_ty) => self.normalize_projection_ty(proj_ty),
             _ => ty,
@@ -477,40 +351,13 @@ fn normalize_associated_types_in(&mut self, ty: Ty) -> Ty {
     }
 
     fn normalize_projection_ty(&mut self, proj_ty: ProjectionTy) -> Ty {
-        let var = self.new_type_var();
+        let var = self.table.new_type_var();
         let predicate = ProjectionPredicate { projection_ty: proj_ty, ty: var.clone() };
         let obligation = Obligation::Projection(predicate);
         self.obligations.push(obligation);
         var
     }
 
-    /// Resolves the type completely; type variables without known type are
-    /// replaced by Ty::Unknown.
-    fn resolve_ty_completely(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
-        ty.fold(&mut |ty| match ty {
-            Ty::Infer(tv) => {
-                let inner = tv.to_inner();
-                if tv_stack.contains(&inner) {
-                    tested_by!(type_var_cycles_resolve_completely);
-                    // recursive type
-                    return tv.fallback_value();
-                }
-                if let Some(known_ty) =
-                    self.var_unification_table.inlined_probe_value(inner).known()
-                {
-                    // known_ty may contain other variables that are known by now
-                    tv_stack.push(inner);
-                    let result = self.resolve_ty_completely(tv_stack, known_ty.clone());
-                    tv_stack.pop();
-                    result
-                } else {
-                    tv.fallback_value()
-                }
-            }
-            _ => ty,
-        })
-    }
-
     fn resolve_variant(&mut self, path: Option<&Path>) -> (Ty, Option<VariantId>) {
         let path = match path {
             Some(path) => path,
@@ -615,78 +462,20 @@ fn resolve_range_to_inclusive(&self) -> Option<AdtId> {
     }
 }
 
-/// The ID of a type variable.
-#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
-pub struct TypeVarId(pub(super) u32);
-
-impl UnifyKey for TypeVarId {
-    type Value = TypeVarValue;
-
-    fn index(&self) -> u32 {
-        self.0
-    }
-
-    fn from_index(i: u32) -> Self {
-        TypeVarId(i)
-    }
-
-    fn tag() -> &'static str {
-        "TypeVarId"
-    }
-}
-
-/// The value of a type variable: either we already know the type, or we don't
-/// know it yet.
-#[derive(Clone, PartialEq, Eq, Debug)]
-pub enum TypeVarValue {
-    Known(Ty),
-    Unknown,
-}
-
-impl TypeVarValue {
-    fn known(&self) -> Option<&Ty> {
-        match self {
-            TypeVarValue::Known(ty) => Some(ty),
-            TypeVarValue::Unknown => None,
-        }
-    }
-}
-
-impl UnifyValue for TypeVarValue {
-    type Error = NoError;
-
-    fn unify_values(value1: &Self, value2: &Self) -> Result<Self, NoError> {
-        match (value1, value2) {
-            // We should never equate two type variables, both of which have
-            // known types. Instead, we recursively equate those types.
-            (TypeVarValue::Known(t1), TypeVarValue::Known(t2)) => panic!(
-                "equating two type variables, both of which have known types: {:?} and {:?}",
-                t1, t2
-            ),
-
-            // If one side is known, prefer that one.
-            (TypeVarValue::Known(..), TypeVarValue::Unknown) => Ok(value1.clone()),
-            (TypeVarValue::Unknown, TypeVarValue::Known(..)) => Ok(value2.clone()),
-
-            (TypeVarValue::Unknown, TypeVarValue::Unknown) => Ok(TypeVarValue::Unknown),
-        }
-    }
-}
-
 /// The kinds of placeholders we need during type inference. There's separate
 /// values for general types, and for integer and float variables. The latter
 /// two are used for inference of literal values (e.g. `100` could be one of
 /// several integer types).
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 pub enum InferTy {
-    TypeVar(TypeVarId),
-    IntVar(TypeVarId),
-    FloatVar(TypeVarId),
-    MaybeNeverTypeVar(TypeVarId),
+    TypeVar(unify::TypeVarId),
+    IntVar(unify::TypeVarId),
+    FloatVar(unify::TypeVarId),
+    MaybeNeverTypeVar(unify::TypeVarId),
 }
 
 impl InferTy {
-    fn to_inner(self) -> TypeVarId {
+    fn to_inner(self) -> unify::TypeVarId {
         match self {
             InferTy::TypeVar(ty)
             | InferTy::IntVar(ty)
index 064993d34fac0e47877ce91d681e53e4d0c53332..9bfc701cd56861ee9225a17bd872392d12b8f72b 100644 (file)
@@ -10,7 +10,7 @@
 
 use crate::{autoderef, db::HirDatabase, Substs, Ty, TypeCtor, TypeWalk};
 
-use super::{InEnvironment, InferTy, InferenceContext, TypeVarValue};
+use super::{InEnvironment, InferTy, InferenceContext, unify::TypeVarValue};
 
 impl<'a, D: HirDatabase> InferenceContext<'a, D> {
     /// Unify two types, but may coerce the first one to the second one
@@ -85,8 +85,8 @@ fn coerce_inner(&mut self, mut from_ty: Ty, to_ty: &Ty) -> bool {
         match (&from_ty, to_ty) {
             // Never type will make type variable to fallback to Never Type instead of Unknown.
             (ty_app!(TypeCtor::Never), Ty::Infer(InferTy::TypeVar(tv))) => {
-                let var = self.new_maybe_never_type_var();
-                self.var_unification_table.union_value(*tv, TypeVarValue::Known(var));
+                let var = self.table.new_maybe_never_type_var();
+                self.table.var_unification_table.union_value(*tv, TypeVarValue::Known(var));
                 return true;
             }
             (ty_app!(TypeCtor::Never), _) => return true,
@@ -94,7 +94,7 @@ fn coerce_inner(&mut self, mut from_ty: Ty, to_ty: &Ty) -> bool {
             // Trivial cases, this should go after `never` check to
             // avoid infer result type to be never
             _ => {
-                if self.unify_inner_trivial(&from_ty, &to_ty) {
+                if self.table.unify_inner_trivial(&from_ty, &to_ty) {
                     return true;
                 }
             }
@@ -330,7 +330,7 @@ fn unify_autoderef_behind_ref(&mut self, from_ty: &Ty, to_ty: &Ty) -> bool {
                 // Stop when constructor matches.
                 (ty_app!(from_ctor, st1), ty_app!(to_ctor, st2)) if from_ctor == to_ctor => {
                     // It will not recurse to `coerce`.
-                    return self.unify_substs(st1, st2, 0);
+                    return self.table.unify_substs(st1, st2, 0);
                 }
                 _ => {}
             }
index 4014f4732d483bb07a88a541f9cd34ed0c2034d4..1e78f6efd46f7da85329321f0ced3036a6ac3f1e 100644 (file)
@@ -32,7 +32,7 @@ pub(super) fn infer_expr(&mut self, tgt_expr: ExprId, expected: &Expectation) ->
                 TypeMismatch { expected: expected.ty.clone(), actual: ty.clone() },
             );
         }
-        let ty = self.resolve_ty_as_possible(&mut vec![], ty);
+        let ty = self.resolve_ty_as_possible(ty);
         ty
     }
 
@@ -53,7 +53,7 @@ fn infer_expr_coerce(&mut self, expr: ExprId, expected: &Expectation) -> Ty {
             expected.ty.clone()
         };
 
-        self.resolve_ty_as_possible(&mut vec![], ty)
+        self.resolve_ty_as_possible(ty)
     }
 
     fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
@@ -94,7 +94,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
 
                 let pat_ty = match self.resolve_into_iter_item() {
                     Some(into_iter_item_alias) => {
-                        let pat_ty = self.new_type_var();
+                        let pat_ty = self.table.new_type_var();
                         let projection = ProjectionPredicate {
                             ty: pat_ty.clone(),
                             projection_ty: ProjectionTy {
@@ -103,7 +103,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                             },
                         };
                         self.obligations.push(Obligation::Projection(projection));
-                        self.resolve_ty_as_possible(&mut vec![], pat_ty)
+                        self.resolve_ty_as_possible(pat_ty)
                     }
                     None => Ty::Unknown,
                 };
@@ -128,7 +128,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                 }
 
                 // add return type
-                let ret_ty = self.new_type_var();
+                let ret_ty = self.table.new_type_var();
                 sig_tys.push(ret_ty.clone());
                 let sig_ty = Ty::apply(
                     TypeCtor::FnPtr { num_args: sig_tys.len() as u16 - 1 },
@@ -167,7 +167,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
             Expr::Match { expr, arms } => {
                 let input_ty = self.infer_expr(*expr, &Expectation::none());
 
-                let mut result_ty = self.new_maybe_never_type_var();
+                let mut result_ty = self.table.new_maybe_never_type_var();
 
                 for arm in arms {
                     for &pat in &arm.pats {
@@ -283,7 +283,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                 let inner_ty = self.infer_expr(*expr, &Expectation::none());
                 let ty = match self.resolve_future_future_output() {
                     Some(future_future_output_alias) => {
-                        let ty = self.new_type_var();
+                        let ty = self.table.new_type_var();
                         let projection = ProjectionPredicate {
                             ty: ty.clone(),
                             projection_ty: ProjectionTy {
@@ -292,7 +292,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                             },
                         };
                         self.obligations.push(Obligation::Projection(projection));
-                        self.resolve_ty_as_possible(&mut vec![], ty)
+                        self.resolve_ty_as_possible(ty)
                     }
                     None => Ty::Unknown,
                 };
@@ -302,7 +302,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                 let inner_ty = self.infer_expr(*expr, &Expectation::none());
                 let ty = match self.resolve_ops_try_ok() {
                     Some(ops_try_ok_alias) => {
-                        let ty = self.new_type_var();
+                        let ty = self.table.new_type_var();
                         let projection = ProjectionPredicate {
                             ty: ty.clone(),
                             projection_ty: ProjectionTy {
@@ -311,7 +311,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                             },
                         };
                         self.obligations.push(Obligation::Projection(projection));
-                        self.resolve_ty_as_possible(&mut vec![], ty)
+                        self.resolve_ty_as_possible(ty)
                     }
                     None => Ty::Unknown,
                 };
@@ -465,10 +465,10 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                     ty_app!(TypeCtor::Tuple { .. }, st) => st
                         .iter()
                         .cloned()
-                        .chain(repeat_with(|| self.new_type_var()))
+                        .chain(repeat_with(|| self.table.new_type_var()))
                         .take(exprs.len())
                         .collect::<Vec<_>>(),
-                    _ => (0..exprs.len()).map(|_| self.new_type_var()).collect(),
+                    _ => (0..exprs.len()).map(|_| self.table.new_type_var()).collect(),
                 };
 
                 for (expr, ty) in exprs.iter().zip(tys.iter_mut()) {
@@ -482,7 +482,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                     ty_app!(TypeCtor::Array, st) | ty_app!(TypeCtor::Slice, st) => {
                         st.as_single().clone()
                     }
-                    _ => self.new_type_var(),
+                    _ => self.table.new_type_var(),
                 };
 
                 match array {
@@ -524,7 +524,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
         };
         // use a new type variable if we got Ty::Unknown here
         let ty = self.insert_type_vars_shallow(ty);
-        let ty = self.resolve_ty_as_possible(&mut vec![], ty);
+        let ty = self.resolve_ty_as_possible(ty);
         self.write_expr_ty(tgt_expr, ty.clone());
         ty
     }
@@ -553,7 +553,7 @@ fn infer_block(
                         }
                     }
 
-                    let ty = self.resolve_ty_as_possible(&mut vec![], ty);
+                    let ty = self.resolve_ty_as_possible(ty);
                     self.infer_pat(*pat, &ty, BindingMode::default());
                 }
                 Statement::Expr(expr) => {
index 1ebb362399a9d9cdac14a01b4fc50f388f19d9cb..a14662884458c9fe05307b582f2200c06d863b0c 100644 (file)
@@ -170,7 +170,7 @@ pub(super) fn infer_pat(
                     }
                     BindingMode::Move => inner_ty.clone(),
                 };
-                let bound_ty = self.resolve_ty_as_possible(&mut vec![], bound_ty);
+                let bound_ty = self.resolve_ty_as_possible(bound_ty);
                 self.write_pat_ty(pat, bound_ty);
                 return inner_ty;
             }
@@ -179,7 +179,7 @@ pub(super) fn infer_pat(
         // use a new type variable if we got Ty::Unknown here
         let ty = self.insert_type_vars_shallow(ty);
         self.unify(&ty, expected);
-        let ty = self.resolve_ty_as_possible(&mut vec![], ty);
+        let ty = self.resolve_ty_as_possible(ty);
         self.write_pat_ty(pat, ty.clone());
         ty
     }
index bbf146418e14fdf6ff7ae78eaefa2e2d4469773e..d0d7646a49d3345d3f301c9619df0450f04eabd6 100644 (file)
@@ -57,7 +57,7 @@ fn resolve_value_path(
         let typable: ValueTyDefId = match value {
             ValueNs::LocalBinding(pat) => {
                 let ty = self.result.type_of_pat.get(pat)?.clone();
-                let ty = self.resolve_ty_as_possible(&mut vec![], ty);
+                let ty = self.resolve_ty_as_possible(ty);
                 return Some(ty);
             }
             ValueNs::FunctionId(it) => it.into(),
@@ -211,7 +211,7 @@ fn resolve_ty_assoc_item(
                         // we're picking this method
                         let trait_substs = Substs::build_for_def(self.db, trait_)
                             .push(ty.clone())
-                            .fill(std::iter::repeat_with(|| self.new_type_var()))
+                            .fill(std::iter::repeat_with(|| self.table.new_type_var()))
                             .build();
                         let substs = Substs::build_for_def(self.db, item)
                             .use_parent_substs(&trait_substs)
index f3a8756785d852fe87b64970eed8247c90df68e3..ff50138f5d990a01ec21b70df73c6824817aca0c 100644 (file)
@@ -1,9 +1,15 @@
 //! Unification and canonicalization logic.
 
+use std::borrow::Cow;
+
+use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue};
+
+use test_utils::tested_by;
+
 use super::{InferenceContext, Obligation};
 use crate::{
     db::HirDatabase, utils::make_mut_slice, Canonical, InEnvironment, InferTy, ProjectionPredicate,
-    ProjectionTy, Substs, TraitRef, Ty, TypeWalk,
+    ProjectionTy, Substs, TraitRef, Ty, TypeCtor, TypeWalk,
 };
 
 impl<'a, D: HirDatabase> InferenceContext<'a, D> {
@@ -24,7 +30,7 @@ pub(super) struct Canonicalizer<'a, 'b, D: HirDatabase>
     /// A stack of type variables that is used to detect recursive types (which
     /// are an error, but we need to protect against them to avoid stack
     /// overflows).
-    var_stack: Vec<super::TypeVarId>,
+    var_stack: Vec<TypeVarId>,
 }
 
 pub(super) struct Canonicalized<T> {
@@ -53,14 +59,14 @@ fn do_canonicalize_ty(&mut self, ty: Ty) -> Ty {
                     return tv.fallback_value();
                 }
                 if let Some(known_ty) =
-                    self.ctx.var_unification_table.inlined_probe_value(inner).known()
+                    self.ctx.table.var_unification_table.inlined_probe_value(inner).known()
                 {
                     self.var_stack.push(inner);
                     let result = self.do_canonicalize_ty(known_ty.clone());
                     self.var_stack.pop();
                     result
                 } else {
-                    let root = self.ctx.var_unification_table.find(inner);
+                    let root = self.ctx.table.var_unification_table.find(inner);
                     let free_var = match tv {
                         InferTy::TypeVar(_) => InferTy::TypeVar(root),
                         InferTy::IntVar(_) => InferTy::IntVar(root),
@@ -153,10 +159,264 @@ pub fn apply_solution(
         solution: Canonical<Vec<Ty>>,
     ) {
         // the solution may contain new variables, which we need to convert to new inference vars
-        let new_vars = Substs((0..solution.num_vars).map(|_| ctx.new_type_var()).collect());
+        let new_vars = Substs((0..solution.num_vars).map(|_| ctx.table.new_type_var()).collect());
         for (i, ty) in solution.value.into_iter().enumerate() {
             let var = self.free_vars[i];
-            ctx.unify(&Ty::Infer(var), &ty.subst_bound_vars(&new_vars));
+            ctx.table.unify(&Ty::Infer(var), &ty.subst_bound_vars(&new_vars));
+        }
+    }
+}
+
+pub fn unify(ty1: Canonical<&Ty>, ty2: &Ty) -> Substs {
+    let mut table = InferenceTable::new();
+    let vars = Substs::builder(ty1.num_vars)
+        .fill(std::iter::repeat_with(|| table.new_type_var())).build();
+    let ty_with_vars = ty1.value.clone().subst_bound_vars(&vars);
+    table.unify(&ty_with_vars, ty2);
+    Substs::builder(ty1.num_vars).fill(vars.iter().map(|v| table.resolve_ty_completely(v.clone()))).build()
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct InferenceTable {
+    pub(super) var_unification_table: InPlaceUnificationTable<TypeVarId>,
+}
+
+impl InferenceTable {
+    pub fn new() -> Self {
+        InferenceTable {
+            var_unification_table: InPlaceUnificationTable::new(),
+        }
+    }
+
+    pub fn new_type_var(&mut self) -> Ty {
+        Ty::Infer(InferTy::TypeVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
+    }
+
+    pub fn new_integer_var(&mut self) -> Ty {
+        Ty::Infer(InferTy::IntVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
+    }
+
+    pub fn new_float_var(&mut self) -> Ty {
+        Ty::Infer(InferTy::FloatVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
+    }
+
+    pub fn new_maybe_never_type_var(&mut self) -> Ty {
+        Ty::Infer(InferTy::MaybeNeverTypeVar(
+            self.var_unification_table.new_key(TypeVarValue::Unknown),
+        ))
+    }
+
+    pub fn resolve_ty_completely(&mut self, ty: Ty) -> Ty {
+        self.resolve_ty_completely_inner(&mut Vec::new(), ty)
+    }
+
+    pub fn resolve_ty_as_possible(&mut self, ty: Ty) -> Ty {
+        self.resolve_ty_as_possible_inner(&mut Vec::new(), ty)
+    }
+
+    pub fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
+        self.unify_inner(ty1, ty2, 0)
+    }
+
+    pub fn unify_substs(&mut self, substs1: &Substs, substs2: &Substs, depth: usize) -> bool {
+        substs1.0.iter().zip(substs2.0.iter()).all(|(t1, t2)| self.unify_inner(t1, t2, depth))
+    }
+
+    fn unify_inner(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool {
+        if depth > 1000 {
+            // prevent stackoverflows
+            panic!("infinite recursion in unification");
+        }
+        if ty1 == ty2 {
+            return true;
+        }
+        // try to resolve type vars first
+        let ty1 = self.resolve_ty_shallow(ty1);
+        let ty2 = self.resolve_ty_shallow(ty2);
+        match (&*ty1, &*ty2) {
+            (Ty::Apply(a_ty1), Ty::Apply(a_ty2)) if a_ty1.ctor == a_ty2.ctor => {
+                self.unify_substs(&a_ty1.parameters, &a_ty2.parameters, depth + 1)
+            }
+            _ => self.unify_inner_trivial(&ty1, &ty2),
+        }
+    }
+
+    pub(super) fn unify_inner_trivial(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
+        match (ty1, ty2) {
+            (Ty::Unknown, _) | (_, Ty::Unknown) => true,
+
+            (Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2)))
+            | (Ty::Infer(InferTy::IntVar(tv1)), Ty::Infer(InferTy::IntVar(tv2)))
+            | (Ty::Infer(InferTy::FloatVar(tv1)), Ty::Infer(InferTy::FloatVar(tv2)))
+            | (
+                Ty::Infer(InferTy::MaybeNeverTypeVar(tv1)),
+                Ty::Infer(InferTy::MaybeNeverTypeVar(tv2)),
+            ) => {
+                // both type vars are unknown since we tried to resolve them
+                self.var_unification_table.union(*tv1, *tv2);
+                true
+            }
+
+            // The order of MaybeNeverTypeVar matters here.
+            // Unifying MaybeNeverTypeVar and TypeVar will let the latter become MaybeNeverTypeVar.
+            // Unifying MaybeNeverTypeVar and other concrete type will let the former become it.
+            (Ty::Infer(InferTy::TypeVar(tv)), other)
+            | (other, Ty::Infer(InferTy::TypeVar(tv)))
+            | (Ty::Infer(InferTy::MaybeNeverTypeVar(tv)), other)
+            | (other, Ty::Infer(InferTy::MaybeNeverTypeVar(tv)))
+            | (Ty::Infer(InferTy::IntVar(tv)), other @ ty_app!(TypeCtor::Int(_)))
+            | (other @ ty_app!(TypeCtor::Int(_)), Ty::Infer(InferTy::IntVar(tv)))
+            | (Ty::Infer(InferTy::FloatVar(tv)), other @ ty_app!(TypeCtor::Float(_)))
+            | (other @ ty_app!(TypeCtor::Float(_)), Ty::Infer(InferTy::FloatVar(tv))) => {
+                // the type var is unknown since we tried to resolve it
+                self.var_unification_table.union_value(*tv, TypeVarValue::Known(other.clone()));
+                true
+            }
+
+            _ => false,
+        }
+    }
+
+    /// If `ty` is a type variable with known type, returns that type;
+    /// otherwise, return ty.
+    pub fn resolve_ty_shallow<'b>(&mut self, ty: &'b Ty) -> Cow<'b, Ty> {
+        let mut ty = Cow::Borrowed(ty);
+        // The type variable could resolve to a int/float variable. Hence try
+        // resolving up to three times; each type of variable shouldn't occur
+        // more than once
+        for i in 0..3 {
+            if i > 0 {
+                tested_by!(type_var_resolves_to_int_var);
+            }
+            match &*ty {
+                Ty::Infer(tv) => {
+                    let inner = tv.to_inner();
+                    match self.var_unification_table.inlined_probe_value(inner).known() {
+                        Some(known_ty) => {
+                            // The known_ty can't be a type var itself
+                            ty = Cow::Owned(known_ty.clone());
+                        }
+                        _ => return ty,
+                    }
+                }
+                _ => return ty,
+            }
+        }
+        log::error!("Inference variable still not resolved: {:?}", ty);
+        ty
+    }
+
+    /// Resolves the type as far as currently possible, replacing type variables
+    /// by their known types. All types returned by the infer_* functions should
+    /// be resolved as far as possible, i.e. contain no type variables with
+    /// known type.
+    fn resolve_ty_as_possible_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
+        ty.fold(&mut |ty| match ty {
+            Ty::Infer(tv) => {
+                let inner = tv.to_inner();
+                if tv_stack.contains(&inner) {
+                    tested_by!(type_var_cycles_resolve_as_possible);
+                    // recursive type
+                    return tv.fallback_value();
+                }
+                if let Some(known_ty) =
+                    self.var_unification_table.inlined_probe_value(inner).known()
+                {
+                    // known_ty may contain other variables that are known by now
+                    tv_stack.push(inner);
+                    let result = self.resolve_ty_as_possible_inner(tv_stack, known_ty.clone());
+                    tv_stack.pop();
+                    result
+                } else {
+                    ty
+                }
+            }
+            _ => ty,
+        })
+    }
+
+    /// Resolves the type completely; type variables without known type are
+    /// replaced by Ty::Unknown.
+    fn resolve_ty_completely_inner(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
+        ty.fold(&mut |ty| match ty {
+            Ty::Infer(tv) => {
+                let inner = tv.to_inner();
+                if tv_stack.contains(&inner) {
+                    tested_by!(type_var_cycles_resolve_completely);
+                    // recursive type
+                    return tv.fallback_value();
+                }
+                if let Some(known_ty) =
+                    self.var_unification_table.inlined_probe_value(inner).known()
+                {
+                    // known_ty may contain other variables that are known by now
+                    tv_stack.push(inner);
+                    let result = self.resolve_ty_completely_inner(tv_stack, known_ty.clone());
+                    tv_stack.pop();
+                    result
+                } else {
+                    tv.fallback_value()
+                }
+            }
+            _ => ty,
+        })
+    }
+}
+
+/// The ID of a type variable.
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+pub struct TypeVarId(pub(super) u32);
+
+impl UnifyKey for TypeVarId {
+    type Value = TypeVarValue;
+
+    fn index(&self) -> u32 {
+        self.0
+    }
+
+    fn from_index(i: u32) -> Self {
+        TypeVarId(i)
+    }
+
+    fn tag() -> &'static str {
+        "TypeVarId"
+    }
+}
+
+/// The value of a type variable: either we already know the type, or we don't
+/// know it yet.
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub enum TypeVarValue {
+    Known(Ty),
+    Unknown,
+}
+
+impl TypeVarValue {
+    fn known(&self) -> Option<&Ty> {
+        match self {
+            TypeVarValue::Known(ty) => Some(ty),
+            TypeVarValue::Unknown => None,
+        }
+    }
+}
+
+impl UnifyValue for TypeVarValue {
+    type Error = NoError;
+
+    fn unify_values(value1: &Self, value2: &Self) -> Result<Self, NoError> {
+        match (value1, value2) {
+            // We should never equate two type variables, both of which have
+            // known types. Instead, we recursively equate those types.
+            (TypeVarValue::Known(t1), TypeVarValue::Known(t2)) => panic!(
+                "equating two type variables, both of which have known types: {:?} and {:?}",
+                t1, t2
+            ),
+
+            // If one side is known, prefer that one.
+            (TypeVarValue::Known(..), TypeVarValue::Unknown) => Ok(value1.clone()),
+            (TypeVarValue::Unknown, TypeVarValue::Known(..)) => Ok(value2.clone()),
+
+            (TypeVarValue::Unknown, TypeVarValue::Unknown) => Ok(TypeVarValue::Unknown),
         }
     }
 }