]> git.lizzy.rs Git - rust.git/blob - src/librustc/middle/typeck/infer/unify.rs
78c841afa609bbf318e8e86904fc6e48cb71d664
[rust.git] / src / librustc / middle / typeck / infer / unify.rs
1 // Copyright 2012-2014 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11
12 use std::collections::SmallIntMap;
13
14 use middle::ty::{Vid, expected_found, IntVarValue};
15 use middle::ty;
16 use middle::typeck::infer::{Bounds, uok, ures};
17 use middle::typeck::infer::InferCtxt;
18 use middle::typeck::infer::to_str::InferStr;
19 use std::cell::RefCell;
20 use syntax::ast;
21
22 #[deriving(Clone)]
23 pub enum VarValue<V, T> {
24     Redirect(V),
25     Root(T, uint),
26 }
27
28 pub struct ValsAndBindings<V, T> {
29     pub vals: SmallIntMap<VarValue<V, T>>,
30     pub bindings: Vec<(V, VarValue<V, T>)> ,
31 }
32
33 impl<V:Clone, T:Clone> ValsAndBindings<V, T> {
34     pub fn new() -> ValsAndBindings<V, T> {
35         ValsAndBindings {
36             vals: SmallIntMap::new(),
37             bindings: Vec::new()
38         }
39     }
40 }
41
42 pub struct Node<V, T> {
43     pub root: V,
44     pub possible_types: T,
45     pub rank: uint,
46 }
47
48 pub trait UnifyVid<T> {
49     fn appropriate_vals_and_bindings<'v>(infcx: &'v InferCtxt)
50                                      -> &'v RefCell<ValsAndBindings<Self, T>>;
51 }
52
53 pub trait UnifyInferCtxtMethods {
54     fn get<T:Clone,
55            V:Clone + PartialEq + Vid + UnifyVid<T>>(
56            &self,
57            vid: V)
58            -> Node<V, T>;
59     fn set<T:Clone + InferStr,
60            V:Clone + Vid + ToStr + UnifyVid<T>>(
61            &self,
62            vid: V,
63            new_v: VarValue<V, T>);
64     fn unify<T:Clone + InferStr,
65              V:Clone + Vid + ToStr + UnifyVid<T>>(
66              &self,
67              node_a: &Node<V, T>,
68              node_b: &Node<V, T>)
69              -> (V, uint);
70 }
71
72 impl<'a> UnifyInferCtxtMethods for InferCtxt<'a> {
73     fn get<T:Clone,
74            V:Clone + PartialEq + Vid + UnifyVid<T>>(
75            &self,
76            vid: V)
77            -> Node<V, T> {
78         /*!
79          *
80          * Find the root node for `vid`. This uses the standard
81          * union-find algorithm with path compression:
82          * http://en.wikipedia.org/wiki/Disjoint-set_data_structure
83          */
84
85         let tcx = self.tcx;
86         let vb = UnifyVid::appropriate_vals_and_bindings(self);
87         return helper(tcx, &mut *vb.borrow_mut(), vid);
88
89         fn helper<T:Clone, V:Clone+PartialEq+Vid>(
90             tcx: &ty::ctxt,
91             vb: &mut ValsAndBindings<V,T>,
92             vid: V) -> Node<V, T>
93         {
94             let vid_u = vid.to_uint();
95             let var_val = match vb.vals.find(&vid_u) {
96                 Some(&ref var_val) => (*var_val).clone(),
97                 None => {
98                     tcx.sess.bug(format!(
99                         "failed lookup of vid `{}`", vid_u).as_slice());
100                 }
101             };
102             match var_val {
103                 Redirect(vid) => {
104                     let node: Node<V,T> = helper(tcx, vb, vid.clone());
105                     if node.root != vid {
106                         // Path compression
107                         vb.vals.insert(vid.to_uint(),
108                                        Redirect(node.root.clone()));
109                     }
110                     node
111                 }
112                 Root(pt, rk) => {
113                     Node {root: vid, possible_types: pt, rank: rk}
114                 }
115             }
116         }
117     }
118
119     fn set<T:Clone + InferStr,
120            V:Clone + Vid + ToStr + UnifyVid<T>>(
121            &self,
122            vid: V,
123            new_v: VarValue<V, T>) {
124         /*!
125          *
126          * Sets the value for `vid` to `new_v`.  `vid` MUST be a root node!
127          */
128
129         debug!("Updating variable {} to {}",
130                vid.to_str(), new_v.inf_str(self));
131
132         let vb = UnifyVid::appropriate_vals_and_bindings(self);
133         let mut vb = vb.borrow_mut();
134         let old_v = (*vb.vals.get(&vid.to_uint())).clone();
135         vb.bindings.push((vid.clone(), old_v));
136         vb.vals.insert(vid.to_uint(), new_v);
137     }
138
139     fn unify<T:Clone + InferStr,
140              V:Clone + Vid + ToStr + UnifyVid<T>>(
141              &self,
142              node_a: &Node<V, T>,
143              node_b: &Node<V, T>)
144              -> (V, uint) {
145         // Rank optimization: if you don't know what it is, check
146         // out <http://en.wikipedia.org/wiki/Disjoint-set_data_structure>
147
148         debug!("unify(node_a(id={:?}, rank={:?}), \
149                 node_b(id={:?}, rank={:?}))",
150                node_a.root, node_a.rank,
151                node_b.root, node_b.rank);
152
153         if node_a.rank > node_b.rank {
154             // a has greater rank, so a should become b's parent,
155             // i.e., b should redirect to a.
156             self.set(node_b.root.clone(), Redirect(node_a.root.clone()));
157             (node_a.root.clone(), node_a.rank)
158         } else if node_a.rank < node_b.rank {
159             // b has greater rank, so a should redirect to b.
160             self.set(node_a.root.clone(), Redirect(node_b.root.clone()));
161             (node_b.root.clone(), node_b.rank)
162         } else {
163             // If equal, redirect one to the other and increment the
164             // other's rank.
165             assert_eq!(node_a.rank, node_b.rank);
166             self.set(node_b.root.clone(), Redirect(node_a.root.clone()));
167             (node_a.root.clone(), node_a.rank + 1)
168         }
169     }
170
171 }
172
173 // ______________________________________________________________________
174 // Code to handle simple variables like ints, floats---anything that
175 // doesn't have a subtyping relationship we need to worry about.
176
177 pub trait SimplyUnifiable {
178     fn to_type_err(expected_found<Self>) -> ty::type_err;
179 }
180
181 pub fn mk_err<T:SimplyUnifiable>(a_is_expected: bool,
182                                  a_t: T,
183                                  b_t: T) -> ures {
184     if a_is_expected {
185         Err(SimplyUnifiable::to_type_err(
186             ty::expected_found {expected: a_t, found: b_t}))
187     } else {
188         Err(SimplyUnifiable::to_type_err(
189             ty::expected_found {expected: b_t, found: a_t}))
190     }
191 }
192
193 pub trait InferCtxtMethods {
194     fn simple_vars<T:Clone + PartialEq + InferStr + SimplyUnifiable,
195                    V:Clone + PartialEq + Vid + ToStr + UnifyVid<Option<T>>>(
196                    &self,
197                    a_is_expected: bool,
198                    a_id: V,
199                    b_id: V)
200                    -> ures;
201     fn simple_var_t<T:Clone + PartialEq + InferStr + SimplyUnifiable,
202                     V:Clone + PartialEq + Vid + ToStr + UnifyVid<Option<T>>>(
203                     &self,
204                     a_is_expected: bool,
205                     a_id: V,
206                     b: T)
207                     -> ures;
208 }
209
210 impl<'a> InferCtxtMethods for InferCtxt<'a> {
211     fn simple_vars<T:Clone + PartialEq + InferStr + SimplyUnifiable,
212                    V:Clone + PartialEq + Vid + ToStr + UnifyVid<Option<T>>>(
213                    &self,
214                    a_is_expected: bool,
215                    a_id: V,
216                    b_id: V)
217                    -> ures {
218         /*!
219          *
220          * Unifies two simple variables.  Because simple variables do
221          * not have any subtyping relationships, if both variables
222          * have already been associated with a value, then those two
223          * values must be the same. */
224
225         let node_a = self.get(a_id);
226         let node_b = self.get(b_id);
227         let a_id = node_a.root.clone();
228         let b_id = node_b.root.clone();
229
230         if a_id == b_id { return uok(); }
231
232         let combined = match (&node_a.possible_types, &node_b.possible_types)
233         {
234             (&None, &None) => None,
235             (&Some(ref v), &None) | (&None, &Some(ref v)) => {
236                 Some((*v).clone())
237             }
238             (&Some(ref v1), &Some(ref v2)) => {
239                 if *v1 != *v2 {
240                     return mk_err(a_is_expected, (*v1).clone(), (*v2).clone())
241                 }
242                 Some((*v1).clone())
243             }
244         };
245
246         let (new_root, new_rank) = self.unify(&node_a, &node_b);
247         self.set(new_root, Root(combined, new_rank));
248         return uok();
249     }
250
251     fn simple_var_t<T:Clone + PartialEq + InferStr + SimplyUnifiable,
252                     V:Clone + PartialEq + Vid + ToStr + UnifyVid<Option<T>>>(
253                     &self,
254                     a_is_expected: bool,
255                     a_id: V,
256                     b: T)
257                     -> ures {
258         /*!
259          *
260          * Sets the value of the variable `a_id` to `b`.  Because
261          * simple variables do not have any subtyping relationships,
262          * if `a_id` already has a value, it must be the same as
263          * `b`. */
264
265         let node_a = self.get(a_id);
266         let a_id = node_a.root.clone();
267
268         match node_a.possible_types {
269             None => {
270                 self.set(a_id, Root(Some(b), node_a.rank));
271                 return uok();
272             }
273
274             Some(ref a_t) => {
275                 if *a_t == b {
276                     return uok();
277                 } else {
278                     return mk_err(a_is_expected, (*a_t).clone(), b);
279                 }
280             }
281         }
282     }
283 }
284
285 // ______________________________________________________________________
286
287 impl UnifyVid<Bounds<ty::t>> for ty::TyVid {
288     fn appropriate_vals_and_bindings<'v>(infcx: &'v InferCtxt)
289         -> &'v RefCell<ValsAndBindings<ty::TyVid, Bounds<ty::t>>> {
290         return &infcx.ty_var_bindings;
291     }
292 }
293
294 impl UnifyVid<Option<IntVarValue>> for ty::IntVid {
295     fn appropriate_vals_and_bindings<'v>(infcx: &'v InferCtxt)
296         -> &'v RefCell<ValsAndBindings<ty::IntVid, Option<IntVarValue>>> {
297         return &infcx.int_var_bindings;
298     }
299 }
300
301 impl SimplyUnifiable for IntVarValue {
302     fn to_type_err(err: expected_found<IntVarValue>) -> ty::type_err {
303         return ty::terr_int_mismatch(err);
304     }
305 }
306
307 impl UnifyVid<Option<ast::FloatTy>> for ty::FloatVid {
308     fn appropriate_vals_and_bindings<'v>(infcx: &'v InferCtxt)
309         -> &'v RefCell<ValsAndBindings<ty::FloatVid, Option<ast::FloatTy>>> {
310         return &infcx.float_var_bindings;
311     }
312 }
313
314 impl SimplyUnifiable for ast::FloatTy {
315     fn to_type_err(err: expected_found<ast::FloatTy>) -> ty::type_err {
316         return ty::terr_float_mismatch(err);
317     }
318 }