git.lizzy.rs Git - rust.git/blob - crates/ra_hir/src/ty/infer.rs
Implement support for type aliases
[rust.git] / crates / ra_hir / src / ty / infer.rs
1 //! Type inference, i.e. the process of walking through the code and determining
2 //! the type of each expression and pattern.
3 //!
4 //! For type inference, compare the implementations in rustc (the various
5 //! check_* methods in librustc_typeck/check/mod.rs are a good entry point) and
6 //! IntelliJ-Rust (org.rust.lang.core.types.infer). Our entry point for
7 //! inference here is the `infer` function, which infers the types of all
8 //! expressions in a given function.
9 //!
10 //! During inference, types (i.e. the `Ty` struct) can contain type 'variables'
11 //! which represent currently unknown types; as we walk through the expressions,
12 //! we might determine that certain variables need to be equal to each other, or
13 //! to certain types. To record this, we use the union-find implementation from
14 //! the `ena` crate, which is extracted from rustc.
15
16 use std::borrow::Cow;
17 use std::iter::repeat;
18 use std::ops::Index;
19 use std::sync::Arc;
20 use std::mem;
21
22 use ena::unify::{InPlaceUnificationTable, UnifyKey, UnifyValue, NoError};
23 use ra_arena::map::ArenaMap;
24 use rustc_hash::FxHashMap;
25
26 use test_utils::tested_by;
27
28 use crate::{
29     Function, StructField, Path, Name,
30     FnSignature, AdtDef,
31     HirDatabase,
32     type_ref::{TypeRef, Mutability},
33     expr::{Body, Expr, BindingAnnotation, Literal, ExprId, Pat, PatId, UnaryOp, BinaryOp, Statement, FieldPat, self},
34     generics::GenericParams,
35     path::{GenericArgs, GenericArg},
36     adt::VariantDef,
37     resolve::{Resolver, Resolution},
38     nameres::Namespace
39 };
40 use super::{Ty, TypableDef, Substs, primitive, op};
41
42 /// The entry point of type inference.
43 pub fn infer(db: &impl HirDatabase, func: Function) -> Arc<InferenceResult> {
44     db.check_canceled();
45     let body = func.body(db);
46     let resolver = func.resolver(db);
47     let mut ctx = InferenceContext::new(db, body, resolver);
48
49     let signature = func.signature(db);
50     ctx.collect_fn_signature(&signature);
51
52     ctx.infer_body();
53
54     Arc::new(ctx.resolve_all())
55 }
56
57 /// The result of type inference: A mapping from expressions and patterns to types.
58 #[derive(Clone, PartialEq, Eq, Debug)]
59 pub struct InferenceResult {
60     /// For each method call expr, records the function it resolves to.
61     method_resolutions: FxHashMap<ExprId, Function>,
62     /// For each field access expr, records the field it resolves to.
63     field_resolutions: FxHashMap<ExprId, StructField>,
64     pub(super) type_of_expr: ArenaMap<ExprId, Ty>,
65     pub(super) type_of_pat: ArenaMap<PatId, Ty>,
66 }
67
68 impl InferenceResult {
69     pub fn method_resolution(&self, expr: ExprId) -> Option<Function> {
70         self.method_resolutions.get(&expr).map(|it| *it)
71     }
72     pub fn field_resolution(&self, expr: ExprId) -> Option<StructField> {
73         self.field_resolutions.get(&expr).map(|it| *it)
74     }
75 }
76
77 impl Index<ExprId> for InferenceResult {
78     type Output = Ty;
79
80     fn index(&self, expr: ExprId) -> &Ty {
81         self.type_of_expr.get(expr).unwrap_or(&Ty::Unknown)
82     }
83 }
84
85 impl Index<PatId> for InferenceResult {
86     type Output = Ty;
87
88     fn index(&self, pat: PatId) -> &Ty {
89         self.type_of_pat.get(pat).unwrap_or(&Ty::Unknown)
90     }
91 }
92
/// The inference context contains all information needed during type inference.
///
/// Results (expression/pattern types, method and field resolutions) are
/// accumulated here while the body is walked; `resolve_all` then consumes the
/// context and produces the final [`InferenceResult`].
#[derive(Clone, Debug)]
struct InferenceContext<'a, D: HirDatabase> {
    /// Database used for all queries.
    db: &'a D,
    /// The body of the function being inferred.
    body: Arc<Body>,
    /// Resolver for names and type references (the one passed to `new`).
    resolver: Resolver,
    /// Union-find table holding the values of type, integer and float
    /// inference variables.
    var_unification_table: InPlaceUnificationTable<TypeVarId>,
    /// Function each method-call expression resolved to.
    method_resolutions: FxHashMap<ExprId, Function>,
    /// Struct field each field-access expression resolved to.
    field_resolutions: FxHashMap<ExprId, StructField>,
    /// Type of each expression; may still contain inference variables.
    type_of_expr: ArenaMap<ExprId, Ty>,
    /// Type of each pattern; may still contain inference variables.
    type_of_pat: ArenaMap<PatId, Ty>,
    /// The return type of the function being inferred.
    return_ty: Ty,
}
107
108 impl<'a, D: HirDatabase> InferenceContext<'a, D> {
109     fn new(db: &'a D, body: Arc<Body>, resolver: Resolver) -> Self {
110         InferenceContext {
111             method_resolutions: FxHashMap::default(),
112             field_resolutions: FxHashMap::default(),
113             type_of_expr: ArenaMap::default(),
114             type_of_pat: ArenaMap::default(),
115             var_unification_table: InPlaceUnificationTable::new(),
116             return_ty: Ty::Unknown, // set in collect_fn_signature
117             db,
118             body,
119             resolver,
120         }
121     }
122
123     fn resolve_all(mut self) -> InferenceResult {
124         let mut tv_stack = Vec::new();
125         let mut expr_types = mem::replace(&mut self.type_of_expr, ArenaMap::default());
126         for ty in expr_types.values_mut() {
127             let resolved = self.resolve_ty_completely(&mut tv_stack, mem::replace(ty, Ty::Unknown));
128             *ty = resolved;
129         }
130         let mut pat_types = mem::replace(&mut self.type_of_pat, ArenaMap::default());
131         for ty in pat_types.values_mut() {
132             let resolved = self.resolve_ty_completely(&mut tv_stack, mem::replace(ty, Ty::Unknown));
133             *ty = resolved;
134         }
135         InferenceResult {
136             method_resolutions: self.method_resolutions,
137             field_resolutions: self.field_resolutions,
138             type_of_expr: expr_types,
139             type_of_pat: pat_types,
140         }
141     }
142
143     fn write_expr_ty(&mut self, expr: ExprId, ty: Ty) {
144         self.type_of_expr.insert(expr, ty);
145     }
146
147     fn write_method_resolution(&mut self, expr: ExprId, func: Function) {
148         self.method_resolutions.insert(expr, func);
149     }
150
151     fn write_field_resolution(&mut self, expr: ExprId, field: StructField) {
152         self.field_resolutions.insert(expr, field);
153     }
154
155     fn write_pat_ty(&mut self, pat: PatId, ty: Ty) {
156         self.type_of_pat.insert(pat, ty);
157     }
158
159     fn make_ty(&mut self, type_ref: &TypeRef) -> Ty {
160         let ty = Ty::from_hir(
161             self.db,
162             // TODO use right resolver for block
163             &self.resolver,
164             type_ref,
165         );
166         let ty = self.insert_type_vars(ty);
167         ty
168     }
169
170     fn unify_substs(&mut self, substs1: &Substs, substs2: &Substs, depth: usize) -> bool {
171         substs1.0.iter().zip(substs2.0.iter()).all(|(t1, t2)| self.unify_inner(t1, t2, depth))
172     }
173
174     fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
175         self.unify_inner(ty1, ty2, 0)
176     }
177
178     fn unify_inner(&mut self, ty1: &Ty, ty2: &Ty, depth: usize) -> bool {
179         if depth > 1000 {
180             // prevent stackoverflows
181             panic!("infinite recursion in unification");
182         }
183         if ty1 == ty2 {
184             return true;
185         }
186         // try to resolve type vars first
187         let ty1 = self.resolve_ty_shallow(ty1);
188         let ty2 = self.resolve_ty_shallow(ty2);
189         match (&*ty1, &*ty2) {
190             (Ty::Unknown, ..) => true,
191             (.., Ty::Unknown) => true,
192             (Ty::Int(t1), Ty::Int(t2)) => match (t1, t2) {
193                 (primitive::UncertainIntTy::Unknown, _)
194                 | (_, primitive::UncertainIntTy::Unknown) => true,
195                 _ => t1 == t2,
196             },
197             (Ty::Float(t1), Ty::Float(t2)) => match (t1, t2) {
198                 (primitive::UncertainFloatTy::Unknown, _)
199                 | (_, primitive::UncertainFloatTy::Unknown) => true,
200                 _ => t1 == t2,
201             },
202             (Ty::Bool, _) | (Ty::Str, _) | (Ty::Never, _) | (Ty::Char, _) => ty1 == ty2,
203             (
204                 Ty::Adt { def_id: def_id1, substs: substs1, .. },
205                 Ty::Adt { def_id: def_id2, substs: substs2, .. },
206             ) if def_id1 == def_id2 => self.unify_substs(substs1, substs2, depth + 1),
207             (Ty::Slice(t1), Ty::Slice(t2)) => self.unify_inner(t1, t2, depth + 1),
208             (Ty::RawPtr(t1, m1), Ty::RawPtr(t2, m2)) if m1 == m2 => {
209                 self.unify_inner(t1, t2, depth + 1)
210             }
211             (Ty::Ref(t1, m1), Ty::Ref(t2, m2)) if m1 == m2 => self.unify_inner(t1, t2, depth + 1),
212             (Ty::FnPtr(sig1), Ty::FnPtr(sig2)) if sig1 == sig2 => true,
213             (Ty::Tuple(ts1), Ty::Tuple(ts2)) if ts1.len() == ts2.len() => {
214                 ts1.iter().zip(ts2.iter()).all(|(t1, t2)| self.unify_inner(t1, t2, depth + 1))
215             }
216             (Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2)))
217             | (Ty::Infer(InferTy::IntVar(tv1)), Ty::Infer(InferTy::IntVar(tv2)))
218             | (Ty::Infer(InferTy::FloatVar(tv1)), Ty::Infer(InferTy::FloatVar(tv2))) => {
219                 // both type vars are unknown since we tried to resolve them
220                 self.var_unification_table.union(*tv1, *tv2);
221                 true
222             }
223             (Ty::Infer(InferTy::TypeVar(tv)), other)
224             | (other, Ty::Infer(InferTy::TypeVar(tv)))
225             | (Ty::Infer(InferTy::IntVar(tv)), other)
226             | (other, Ty::Infer(InferTy::IntVar(tv)))
227             | (Ty::Infer(InferTy::FloatVar(tv)), other)
228             | (other, Ty::Infer(InferTy::FloatVar(tv))) => {
229                 // the type var is unknown since we tried to resolve it
230                 self.var_unification_table.union_value(*tv, TypeVarValue::Known(other.clone()));
231                 true
232             }
233             _ => false,
234         }
235     }
236
237     fn new_type_var(&mut self) -> Ty {
238         Ty::Infer(InferTy::TypeVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
239     }
240
241     fn new_integer_var(&mut self) -> Ty {
242         Ty::Infer(InferTy::IntVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
243     }
244
245     fn new_float_var(&mut self) -> Ty {
246         Ty::Infer(InferTy::FloatVar(self.var_unification_table.new_key(TypeVarValue::Unknown)))
247     }
248
249     /// Replaces Ty::Unknown by a new type var, so we can maybe still infer it.
250     fn insert_type_vars_shallow(&mut self, ty: Ty) -> Ty {
251         match ty {
252             Ty::Unknown => self.new_type_var(),
253             Ty::Int(primitive::UncertainIntTy::Unknown) => self.new_integer_var(),
254             Ty::Float(primitive::UncertainFloatTy::Unknown) => self.new_float_var(),
255             _ => ty,
256         }
257     }
258
259     fn insert_type_vars(&mut self, ty: Ty) -> Ty {
260         ty.fold(&mut |ty| self.insert_type_vars_shallow(ty))
261     }
262
263     /// Resolves the type as far as currently possible, replacing type variables
264     /// by their known types. All types returned by the infer_* functions should
265     /// be resolved as far as possible, i.e. contain no type variables with
266     /// known type.
267     fn resolve_ty_as_possible(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
268         ty.fold(&mut |ty| match ty {
269             Ty::Infer(tv) => {
270                 let inner = tv.to_inner();
271                 if tv_stack.contains(&inner) {
272                     tested_by!(type_var_cycles_resolve_as_possible);
273                     // recursive type
274                     return tv.fallback_value();
275                 }
276                 if let Some(known_ty) = self.var_unification_table.probe_value(inner).known() {
277                     // known_ty may contain other variables that are known by now
278                     tv_stack.push(inner);
279                     let result = self.resolve_ty_as_possible(tv_stack, known_ty.clone());
280                     tv_stack.pop();
281                     result
282                 } else {
283                     ty
284                 }
285             }
286             _ => ty,
287         })
288     }
289
290     /// If `ty` is a type variable with known type, returns that type;
291     /// otherwise, return ty.
292     fn resolve_ty_shallow<'b>(&mut self, ty: &'b Ty) -> Cow<'b, Ty> {
293         let mut ty = Cow::Borrowed(ty);
294         // The type variable could resolve to a int/float variable. Hence try
295         // resolving up to three times; each type of variable shouldn't occur
296         // more than once
297         for i in 0..3 {
298             if i > 0 {
299                 tested_by!(type_var_resolves_to_int_var);
300             }
301             match &*ty {
302                 Ty::Infer(tv) => {
303                     let inner = tv.to_inner();
304                     match self.var_unification_table.probe_value(inner).known() {
305                         Some(known_ty) => {
306                             // The known_ty can't be a type var itself
307                             ty = Cow::Owned(known_ty.clone());
308                         }
309                         _ => return ty,
310                     }
311                 }
312                 _ => return ty,
313             }
314         }
315         log::error!("Inference variable still not resolved: {:?}", ty);
316         ty
317     }
318
319     /// Resolves the type completely; type variables without known type are
320     /// replaced by Ty::Unknown.
321     fn resolve_ty_completely(&mut self, tv_stack: &mut Vec<TypeVarId>, ty: Ty) -> Ty {
322         ty.fold(&mut |ty| match ty {
323             Ty::Infer(tv) => {
324                 let inner = tv.to_inner();
325                 if tv_stack.contains(&inner) {
326                     tested_by!(type_var_cycles_resolve_completely);
327                     // recursive type
328                     return tv.fallback_value();
329                 }
330                 if let Some(known_ty) = self.var_unification_table.probe_value(inner).known() {
331                     // known_ty may contain other variables that are known by now
332                     tv_stack.push(inner);
333                     let result = self.resolve_ty_completely(tv_stack, known_ty.clone());
334                     tv_stack.pop();
335                     result
336                 } else {
337                     tv.fallback_value()
338                 }
339             }
340             _ => ty,
341         })
342     }
343
344     fn infer_path_expr(&mut self, resolver: &Resolver, path: &Path) -> Option<Ty> {
345         let resolved = resolver.resolve_path_segments(self.db, &path);
346
347         let (def, remaining_index) = resolved.into_inner();
348
349         log::debug!(
350             "path {:?} resolved to {:?} with remaining index {:?}",
351             path,
352             def,
353             remaining_index
354         );
355
356         // if the remaining_index is None, we expect the path
357         // to be fully resolved, in this case we continue with
358         // the default by attempting to `take_values´ from the resolution.
359         // Otherwise the path was partially resolved, which means
360         // we might have resolved into a type for which
361         // we may find some associated item starting at the
362         // path.segment pointed to by `remaining_index´
363         let mut resolved =
364             if remaining_index.is_none() { def.take_values()? } else { def.take_types()? };
365
366         let remaining_index = remaining_index.unwrap_or(path.segments.len());
367
368         // resolve intermediate segments
369         for segment in &path.segments[remaining_index..] {
370             let ty = match resolved {
371                 Resolution::Def(def) => {
372                     let typable: Option<TypableDef> = def.into();
373                     let typable = typable?;
374
375                     let substs =
376                         Ty::substs_from_path_segment(self.db, &self.resolver, segment, typable);
377                     self.db.type_for_def(typable, Namespace::Types).apply_substs(substs)
378                 }
379                 Resolution::LocalBinding(_) => {
380                     // can't have a local binding in an associated item path
381                     return None;
382                 }
383                 Resolution::GenericParam(..) => {
384                     // TODO associated item of generic param
385                     return None;
386                 }
387                 Resolution::SelfType(_) => {
388                     // TODO associated item of self type
389                     return None;
390                 }
391             };
392
393             // Attempt to find an impl_item for the type which has a name matching
394             // the current segment
395             log::debug!("looking for path segment: {:?}", segment);
396             let item = ty.iterate_impl_items(self.db, |item| match item {
397                 crate::ImplItem::Method(func) => {
398                     let sig = func.signature(self.db);
399                     if segment.name == *sig.name() {
400                         return Some(func);
401                     }
402                     None
403                 }
404
405                 // TODO: Resolve associated const
406                 crate::ImplItem::Const(_) => None,
407
408                 // TODO: Resolve associated types
409                 crate::ImplItem::Type(_) => None,
410             })?;
411             resolved = Resolution::Def(item.into());
412         }
413
414         match resolved {
415             Resolution::Def(def) => {
416                 let typable: Option<TypableDef> = def.into();
417                 let typable = typable?;
418
419                 let substs = Ty::substs_from_path(self.db, &self.resolver, path, typable);
420                 let ty = self.db.type_for_def(typable, Namespace::Values).apply_substs(substs);
421                 let ty = self.insert_type_vars(ty);
422                 Some(ty)
423             }
424             Resolution::LocalBinding(pat) => {
425                 let ty = self.type_of_pat.get(pat)?;
426                 let ty = self.resolve_ty_as_possible(&mut vec![], ty.clone());
427                 Some(ty)
428             }
429             Resolution::GenericParam(..) => {
430                 // generic params can't refer to values... yet
431                 None
432             }
433             Resolution::SelfType(_) => {
434                 log::error!("path expr {:?} resolved to Self type in values ns", path);
435                 None
436             }
437         }
438     }
439
440     fn resolve_variant(&mut self, path: Option<&Path>) -> (Ty, Option<VariantDef>) {
441         let path = match path {
442             Some(path) => path,
443             None => return (Ty::Unknown, None),
444         };
445         let resolver = &self.resolver;
446         let typable: Option<TypableDef> = match resolver.resolve_path(self.db, &path).take_types() {
447             Some(Resolution::Def(def)) => def.into(),
448             Some(Resolution::LocalBinding(..)) => {
449                 // this cannot happen
450                 log::error!("path resolved to local binding in type ns");
451                 return (Ty::Unknown, None);
452             }
453             Some(Resolution::GenericParam(..)) => {
454                 // generic params can't be used in struct literals
455                 return (Ty::Unknown, None);
456             }
457             Some(Resolution::SelfType(..)) => {
458                 // TODO this is allowed in an impl for a struct, handle this
459                 return (Ty::Unknown, None);
460             }
461             None => return (Ty::Unknown, None),
462         };
463         let def = match typable {
464             None => return (Ty::Unknown, None),
465             Some(it) => it,
466         };
467         // TODO remove the duplication between here and `Ty::from_path`?
468         let substs = Ty::substs_from_path(self.db, resolver, path, def);
469         match def {
470             TypableDef::Struct(s) => {
471                 let ty = s.ty(self.db);
472                 let ty = self.insert_type_vars(ty.apply_substs(substs));
473                 (ty, Some(s.into()))
474             }
475             TypableDef::EnumVariant(var) => {
476                 let ty = var.parent_enum(self.db).ty(self.db);
477                 let ty = self.insert_type_vars(ty.apply_substs(substs));
478                 (ty, Some(var.into()))
479             }
480             TypableDef::Type(_) | TypableDef::Function(_) | TypableDef::Enum(_) => {
481                 (Ty::Unknown, None)
482             }
483         }
484     }
485
486     fn infer_tuple_struct_pat(
487         &mut self,
488         path: Option<&Path>,
489         subpats: &[PatId],
490         expected: &Ty,
491     ) -> Ty {
492         let (ty, def) = self.resolve_variant(path);
493
494         self.unify(&ty, expected);
495
496         let substs = ty.substs().unwrap_or_else(Substs::empty);
497
498         for (i, &subpat) in subpats.iter().enumerate() {
499             let expected_ty = def
500                 .and_then(|d| d.field(self.db, &Name::tuple_field_name(i)))
501                 .map_or(Ty::Unknown, |field| field.ty(self.db))
502                 .subst(&substs);
503             self.infer_pat(subpat, &expected_ty);
504         }
505
506         ty
507     }
508
509     fn infer_struct_pat(&mut self, path: Option<&Path>, subpats: &[FieldPat], expected: &Ty) -> Ty {
510         let (ty, def) = self.resolve_variant(path);
511
512         self.unify(&ty, expected);
513
514         let substs = ty.substs().unwrap_or_else(Substs::empty);
515
516         for subpat in subpats {
517             let matching_field = def.and_then(|it| it.field(self.db, &subpat.name));
518             let expected_ty =
519                 matching_field.map_or(Ty::Unknown, |field| field.ty(self.db)).subst(&substs);
520             self.infer_pat(subpat.pat, &expected_ty);
521         }
522
523         ty
524     }
525
    /// Infers the type of pattern `pat`, given the type `expected` of the value
    /// it matches. The result is recorded with `write_pat_ty` and returned.
    ///
    /// For a binding (`Pat::Bind`), the recorded type is the type of the bound
    /// variable (wrapped in a reference for `ref` / `ref mut`). The returned
    /// type is the type of the matched value.
    fn infer_pat(&mut self, pat: PatId, expected: &Ty) -> Ty {
        let body = Arc::clone(&self.body); // avoid borrow checker problem

        let ty = match &body[pat] {
            Pat::Tuple(ref args) => {
                // Element expectations come from `expected` when it is a tuple.
                // Missing elements (all of them, if it is not a tuple) are
                // `Ty::Unknown`.
                let expectations = match *expected {
                    Ty::Tuple(ref tuple_args) => &**tuple_args,
                    _ => &[],
                };
                let expectations_iter = expectations.iter().chain(repeat(&Ty::Unknown));

                let inner_tys = args
                    .iter()
                    .zip(expectations_iter)
                    .map(|(&pat, ty)| self.infer_pat(pat, ty))
                    .collect::<Vec<_>>()
                    .into();

                Ty::Tuple(inner_tys)
            }
            Pat::Ref { pat, mutability } => {
                // Look through the expected reference type. A mutability
                // mismatch is not reported yet.
                let expectation = match *expected {
                    Ty::Ref(ref sub_ty, exp_mut) => {
                        if *mutability != exp_mut {
                            // TODO: emit type error?
                        }
                        &**sub_ty
                    }
                    _ => &Ty::Unknown,
                };
                let subty = self.infer_pat(*pat, expectation);
                Ty::Ref(subty.into(), *mutability)
            }
            Pat::TupleStruct { path: ref p, args: ref subpats } => {
                self.infer_tuple_struct_pat(p.as_ref(), subpats, expected)
            }
            Pat::Struct { path: ref p, args: ref fields } => {
                self.infer_struct_pat(p.as_ref(), fields, expected)
            }
            Pat::Path(path) => {
                // TODO use correct resolver for the surrounding expression
                let resolver = self.resolver.clone();
                self.infer_path_expr(&resolver, &path).unwrap_or(Ty::Unknown)
            }
            Pat::Bind { mode, name: _name, subpat } => {
                // The matched value's type: from the sub-pattern (`x @ pat`)
                // if present, otherwise simply the expected type.
                let inner_ty = if let Some(subpat) = subpat {
                    self.infer_pat(*subpat, expected)
                } else {
                    expected.clone()
                };
                let inner_ty = self.insert_type_vars_shallow(inner_ty);

                // `ref` / `ref mut` bindings bind a reference to the value.
                let bound_ty = match mode {
                    BindingAnnotation::Ref => Ty::Ref(inner_ty.clone().into(), Mutability::Shared),
                    BindingAnnotation::RefMut => Ty::Ref(inner_ty.clone().into(), Mutability::Mut),
                    BindingAnnotation::Mutable | BindingAnnotation::Unannotated => inner_ty.clone(),
                };
                let bound_ty = self.resolve_ty_as_possible(&mut vec![], bound_ty);
                self.write_pat_ty(pat, bound_ty);
                // The binding's type is already recorded above, so return early
                // and skip the generic unify/record steps below.
                return inner_ty;
            }
            // Remaining pattern kinds are not inferred yet.
            _ => Ty::Unknown,
        };
        // use a new type variable if we got Ty::Unknown here
        let ty = self.insert_type_vars_shallow(ty);
        self.unify(&ty, expected);
        let ty = self.resolve_ty_as_possible(&mut vec![], ty);
        self.write_pat_ty(pat, ty.clone());
        ty
    }
596
597     fn substs_for_method_call(
598         &mut self,
599         def_generics: Option<Arc<GenericParams>>,
600         generic_args: &Option<GenericArgs>,
601     ) -> Substs {
602         let (parent_param_count, param_count) =
603             def_generics.map_or((0, 0), |g| (g.count_parent_params(), g.params.len()));
604         let mut substs = Vec::with_capacity(parent_param_count + param_count);
605         for _ in 0..parent_param_count {
606             substs.push(Ty::Unknown);
607         }
608         // handle provided type arguments
609         if let Some(generic_args) = generic_args {
610             // if args are provided, it should be all of them, but we can't rely on that
611             for arg in generic_args.args.iter().take(param_count) {
612                 match arg {
613                     GenericArg::Type(type_ref) => {
614                         let ty = self.make_ty(type_ref);
615                         substs.push(ty);
616                     }
617                 }
618             }
619         };
620         let supplied_params = substs.len();
621         for _ in supplied_params..parent_param_count + param_count {
622             substs.push(Ty::Unknown);
623         }
624         assert_eq!(substs.len(), parent_param_count + param_count);
625         Substs(substs.into())
626     }
627
628     fn infer_expr(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
629         let body = Arc::clone(&self.body); // avoid borrow checker problem
630         let ty = match &body[tgt_expr] {
631             Expr::Missing => Ty::Unknown,
632             Expr::If { condition, then_branch, else_branch } => {
633                 // if let is desugared to match, so this is always simple if
634                 self.infer_expr(*condition, &Expectation::has_type(Ty::Bool));
635                 let then_ty = self.infer_expr(*then_branch, expected);
636                 match else_branch {
637                     Some(else_branch) => {
638                         self.infer_expr(*else_branch, expected);
639                     }
640                     None => {
641                         // no else branch -> unit
642                         self.unify(&then_ty, &Ty::unit()); // actually coerce
643                     }
644                 };
645                 then_ty
646             }
647             Expr::Block { statements, tail } => self.infer_block(statements, *tail, expected),
648             Expr::Loop { body } => {
649                 self.infer_expr(*body, &Expectation::has_type(Ty::unit()));
650                 // TODO handle break with value
651                 Ty::Never
652             }
653             Expr::While { condition, body } => {
654                 // while let is desugared to a match loop, so this is always simple while
655                 self.infer_expr(*condition, &Expectation::has_type(Ty::Bool));
656                 self.infer_expr(*body, &Expectation::has_type(Ty::unit()));
657                 Ty::unit()
658             }
659             Expr::For { iterable, body, pat } => {
660                 let _iterable_ty = self.infer_expr(*iterable, &Expectation::none());
661                 self.infer_pat(*pat, &Ty::Unknown);
662                 self.infer_expr(*body, &Expectation::has_type(Ty::unit()));
663                 Ty::unit()
664             }
665             Expr::Lambda { body, args, arg_types } => {
666                 assert_eq!(args.len(), arg_types.len());
667
668                 for (arg_pat, arg_type) in args.iter().zip(arg_types.iter()) {
669                     let expected = if let Some(type_ref) = arg_type {
670                         let ty = self.make_ty(type_ref);
671                         ty
672                     } else {
673                         Ty::Unknown
674                     };
675                     self.infer_pat(*arg_pat, &expected);
676                 }
677
678                 // TODO: infer lambda type etc.
679                 let _body_ty = self.infer_expr(*body, &Expectation::none());
680                 Ty::Unknown
681             }
682             Expr::Call { callee, args } => {
683                 let callee_ty = self.infer_expr(*callee, &Expectation::none());
684                 let (param_tys, ret_ty) = match &callee_ty {
685                     Ty::FnPtr(sig) => (sig.input.clone(), sig.output.clone()),
686                     Ty::FnDef { substs, sig, .. } => {
687                         let ret_ty = sig.output.clone().subst(&substs);
688                         let param_tys =
689                             sig.input.iter().map(|ty| ty.clone().subst(&substs)).collect();
690                         (param_tys, ret_ty)
691                     }
692                     _ => {
693                         // not callable
694                         // TODO report an error?
695                         (Vec::new(), Ty::Unknown)
696                     }
697                 };
698                 let param_iter = param_tys.into_iter().chain(repeat(Ty::Unknown));
699                 for (arg, param) in args.iter().zip(param_iter) {
700                     self.infer_expr(*arg, &Expectation::has_type(param));
701                 }
702                 ret_ty
703             }
704             Expr::MethodCall { receiver, args, method_name, generic_args } => {
705                 let receiver_ty = self.infer_expr(*receiver, &Expectation::none());
706                 let resolved = receiver_ty.clone().lookup_method(self.db, method_name);
707                 let (derefed_receiver_ty, method_ty, def_generics) = match resolved {
708                     Some((ty, func)) => {
709                         self.write_method_resolution(tgt_expr, func);
710                         (
711                             ty,
712                             self.db.type_for_def(func.into(), Namespace::Values),
713                             Some(func.generic_params(self.db)),
714                         )
715                     }
716                     None => (Ty::Unknown, receiver_ty, None),
717                 };
718                 let substs = self.substs_for_method_call(def_generics, generic_args);
719                 let method_ty = method_ty.apply_substs(substs);
720                 let method_ty = self.insert_type_vars(method_ty);
721                 let (expected_receiver_ty, param_tys, ret_ty) = match &method_ty {
722                     Ty::FnPtr(sig) => {
723                         if !sig.input.is_empty() {
724                             (sig.input[0].clone(), sig.input[1..].to_vec(), sig.output.clone())
725                         } else {
726                             (Ty::Unknown, Vec::new(), sig.output.clone())
727                         }
728                     }
729                     Ty::FnDef { substs, sig, .. } => {
730                         let ret_ty = sig.output.clone().subst(&substs);
731
732                         if !sig.input.is_empty() {
733                             let mut arg_iter = sig.input.iter().map(|ty| ty.clone().subst(&substs));
734                             let receiver_ty = arg_iter.next().unwrap();
735                             (receiver_ty, arg_iter.collect(), ret_ty)
736                         } else {
737                             (Ty::Unknown, Vec::new(), ret_ty)
738                         }
739                     }
740                     _ => (Ty::Unknown, Vec::new(), Ty::Unknown),
741                 };
742                 // Apply autoref so the below unification works correctly
743                 let actual_receiver_ty = match expected_receiver_ty {
744                     Ty::Ref(_, mutability) => Ty::Ref(Arc::new(derefed_receiver_ty), mutability),
745                     _ => derefed_receiver_ty,
746                 };
747                 self.unify(&expected_receiver_ty, &actual_receiver_ty);
748
749                 let param_iter = param_tys.into_iter().chain(repeat(Ty::Unknown));
750                 for (arg, param) in args.iter().zip(param_iter) {
751                     self.infer_expr(*arg, &Expectation::has_type(param));
752                 }
753                 ret_ty
754             }
755             Expr::Match { expr, arms } => {
756                 let expected = if expected.ty == Ty::Unknown {
757                     Expectation::has_type(self.new_type_var())
758                 } else {
759                     expected.clone()
760                 };
761                 let input_ty = self.infer_expr(*expr, &Expectation::none());
762
763                 for arm in arms {
764                     for &pat in &arm.pats {
765                         let _pat_ty = self.infer_pat(pat, &input_ty);
766                     }
767                     if let Some(guard_expr) = arm.guard {
768                         self.infer_expr(guard_expr, &Expectation::has_type(Ty::Bool));
769                     }
770                     self.infer_expr(arm.expr, &expected);
771                 }
772
773                 expected.ty
774             }
775             Expr::Path(p) => {
776                 // TODO this could be more efficient...
777                 let resolver = expr::resolver_for_expr(self.body.clone(), self.db, tgt_expr);
778                 self.infer_path_expr(&resolver, p).unwrap_or(Ty::Unknown)
779             }
780             Expr::Continue => Ty::Never,
781             Expr::Break { expr } => {
782                 if let Some(expr) = expr {
783                     // TODO handle break with value
784                     self.infer_expr(*expr, &Expectation::none());
785                 }
786                 Ty::Never
787             }
788             Expr::Return { expr } => {
789                 if let Some(expr) = expr {
790                     self.infer_expr(*expr, &Expectation::has_type(self.return_ty.clone()));
791                 }
792                 Ty::Never
793             }
794             Expr::StructLit { path, fields, spread } => {
795                 let (ty, def_id) = self.resolve_variant(path.as_ref());
796                 let substs = ty.substs().unwrap_or_else(Substs::empty);
797                 for field in fields {
798                     let field_ty = def_id
799                         .and_then(|it| it.field(self.db, &field.name))
800                         .map_or(Ty::Unknown, |field| field.ty(self.db))
801                         .subst(&substs);
802                     self.infer_expr(field.expr, &Expectation::has_type(field_ty));
803                 }
804                 if let Some(expr) = spread {
805                     self.infer_expr(*expr, &Expectation::has_type(ty.clone()));
806                 }
807                 ty
808             }
809             Expr::Field { expr, name } => {
810                 let receiver_ty = self.infer_expr(*expr, &Expectation::none());
811                 let ty = receiver_ty
812                     .autoderef(self.db)
813                     .find_map(|derefed_ty| match derefed_ty {
814                         Ty::Tuple(fields) => {
815                             let i = name.to_string().parse::<usize>().ok();
816                             i.and_then(|i| fields.get(i).cloned())
817                         }
818                         Ty::Adt { def_id: AdtDef::Struct(s), ref substs, .. } => {
819                             s.field(self.db, name).map(|field| {
820                                 self.write_field_resolution(tgt_expr, field);
821                                 field.ty(self.db).subst(substs)
822                             })
823                         }
824                         _ => None,
825                     })
826                     .unwrap_or(Ty::Unknown);
827                 self.insert_type_vars(ty)
828             }
829             Expr::Try { expr } => {
830                 let _inner_ty = self.infer_expr(*expr, &Expectation::none());
831                 Ty::Unknown
832             }
833             Expr::Cast { expr, type_ref } => {
834                 let _inner_ty = self.infer_expr(*expr, &Expectation::none());
835                 let cast_ty = self.make_ty(type_ref);
836                 // TODO check the cast...
837                 cast_ty
838             }
839             Expr::Ref { expr, mutability } => {
840                 let expectation = if let Ty::Ref(ref subty, expected_mutability) = expected.ty {
841                     if expected_mutability == Mutability::Mut && *mutability == Mutability::Shared {
842                         // TODO: throw type error - expected mut reference but found shared ref,
843                         // which cannot be coerced
844                     }
845                     Expectation::has_type((**subty).clone())
846                 } else {
847                     Expectation::none()
848                 };
849                 // TODO reference coercions etc.
850                 let inner_ty = self.infer_expr(*expr, &expectation);
851                 Ty::Ref(Arc::new(inner_ty), *mutability)
852             }
853             Expr::UnaryOp { expr, op } => {
854                 let inner_ty = self.infer_expr(*expr, &Expectation::none());
855                 match op {
856                     UnaryOp::Deref => {
857                         if let Some(derefed_ty) = inner_ty.builtin_deref() {
858                             derefed_ty
859                         } else {
860                             // TODO Deref::deref
861                             Ty::Unknown
862                         }
863                     }
864                     UnaryOp::Neg => {
865                         match inner_ty {
866                             Ty::Int(primitive::UncertainIntTy::Unknown)
867                             | Ty::Int(primitive::UncertainIntTy::Signed(..))
868                             | Ty::Infer(InferTy::IntVar(..))
869                             | Ty::Infer(InferTy::FloatVar(..))
870                             | Ty::Float(..) => inner_ty,
871                             // TODO: resolve ops::Neg trait
872                             _ => Ty::Unknown,
873                         }
874                     }
875                     UnaryOp::Not => {
876                         match inner_ty {
877                             Ty::Bool | Ty::Int(_) | Ty::Infer(InferTy::IntVar(..)) => inner_ty,
878                             // TODO: resolve ops::Not trait for inner_ty
879                             _ => Ty::Unknown,
880                         }
881                     }
882                 }
883             }
884             Expr::BinaryOp { lhs, rhs, op } => match op {
885                 Some(op) => {
886                     let lhs_expectation = match op {
887                         BinaryOp::BooleanAnd | BinaryOp::BooleanOr => {
888                             Expectation::has_type(Ty::Bool)
889                         }
890                         _ => Expectation::none(),
891                     };
892                     let lhs_ty = self.infer_expr(*lhs, &lhs_expectation);
893                     // TODO: find implementation of trait corresponding to operation
894                     // symbol and resolve associated `Output` type
895                     let rhs_expectation = op::binary_op_rhs_expectation(*op, lhs_ty);
896                     let rhs_ty = self.infer_expr(*rhs, &Expectation::has_type(rhs_expectation));
897
898                     // TODO: similar as above, return ty is often associated trait type
899                     op::binary_op_return_ty(*op, rhs_ty)
900                 }
901                 _ => Ty::Unknown,
902             },
903             Expr::Tuple { exprs } => {
904                 let mut ty_vec = Vec::with_capacity(exprs.len());
905                 for arg in exprs.iter() {
906                     ty_vec.push(self.infer_expr(*arg, &Expectation::none()));
907                 }
908
909                 Ty::Tuple(Arc::from(ty_vec))
910             }
911             Expr::Array { exprs } => {
912                 let elem_ty = match &expected.ty {
913                     Ty::Slice(inner) | Ty::Array(inner) => Ty::clone(&inner),
914                     _ => self.new_type_var(),
915                 };
916
917                 for expr in exprs.iter() {
918                     self.infer_expr(*expr, &Expectation::has_type(elem_ty.clone()));
919                 }
920
921                 Ty::Array(Arc::new(elem_ty))
922             }
923             Expr::Literal(lit) => match lit {
924                 Literal::Bool(..) => Ty::Bool,
925                 Literal::String(..) => Ty::Ref(Arc::new(Ty::Str), Mutability::Shared),
926                 Literal::ByteString(..) => {
927                     let byte_type = Arc::new(Ty::Int(primitive::UncertainIntTy::Unsigned(
928                         primitive::UintTy::U8,
929                     )));
930                     let slice_type = Arc::new(Ty::Slice(byte_type));
931                     Ty::Ref(slice_type, Mutability::Shared)
932                 }
933                 Literal::Char(..) => Ty::Char,
934                 Literal::Int(_v, ty) => Ty::Int(*ty),
935                 Literal::Float(_v, ty) => Ty::Float(*ty),
936             },
937         };
938         // use a new type variable if we got Ty::Unknown here
939         let ty = self.insert_type_vars_shallow(ty);
940         self.unify(&ty, &expected.ty);
941         let ty = self.resolve_ty_as_possible(&mut vec![], ty);
942         self.write_expr_ty(tgt_expr, ty.clone());
943         ty
944     }
945
946     fn infer_block(
947         &mut self,
948         statements: &[Statement],
949         tail: Option<ExprId>,
950         expected: &Expectation,
951     ) -> Ty {
952         for stmt in statements {
953             match stmt {
954                 Statement::Let { pat, type_ref, initializer } => {
955                     let decl_ty =
956                         type_ref.as_ref().map(|tr| self.make_ty(tr)).unwrap_or(Ty::Unknown);
957                     let decl_ty = self.insert_type_vars(decl_ty);
958                     let ty = if let Some(expr) = initializer {
959                         let expr_ty = self.infer_expr(*expr, &Expectation::has_type(decl_ty));
960                         expr_ty
961                     } else {
962                         decl_ty
963                     };
964
965                     self.infer_pat(*pat, &ty);
966                 }
967                 Statement::Expr(expr) => {
968                     self.infer_expr(*expr, &Expectation::none());
969                 }
970             }
971         }
972         let ty = if let Some(expr) = tail { self.infer_expr(expr, expected) } else { Ty::unit() };
973         ty
974     }
975
976     fn collect_fn_signature(&mut self, signature: &FnSignature) {
977         let body = Arc::clone(&self.body); // avoid borrow checker problem
978         for (type_ref, pat) in signature.params().iter().zip(body.params()) {
979             let ty = self.make_ty(type_ref);
980
981             self.infer_pat(*pat, &ty);
982         }
983         self.return_ty = self.make_ty(signature.ret_type());
984     }
985
986     fn infer_body(&mut self) {
987         self.infer_expr(self.body.body_expr(), &Expectation::has_type(self.return_ty.clone()));
988     }
989 }
990
/// Identifier of an inference type variable.
///
/// It wraps the raw `u32` index that keys the unification table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TypeVarId(u32);
994
995 impl UnifyKey for TypeVarId {
996     type Value = TypeVarValue;
997
998     fn index(&self) -> u32 {
999         self.0
1000     }
1001
1002     fn from_index(i: u32) -> Self {
1003         TypeVarId(i)
1004     }
1005
1006     fn tag() -> &'static str {
1007         "TypeVarId"
1008     }
1009 }
1010
1011 /// The value of a type variable: either we already know the type, or we don't
1012 /// know it yet.
1013 #[derive(Clone, PartialEq, Eq, Debug)]
1014 pub enum TypeVarValue {
1015     Known(Ty),
1016     Unknown,
1017 }
1018
1019 impl TypeVarValue {
1020     fn known(&self) -> Option<&Ty> {
1021         match self {
1022             TypeVarValue::Known(ty) => Some(ty),
1023             TypeVarValue::Unknown => None,
1024         }
1025     }
1026 }
1027
1028 impl UnifyValue for TypeVarValue {
1029     type Error = NoError;
1030
1031     fn unify_values(value1: &Self, value2: &Self) -> Result<Self, NoError> {
1032         match (value1, value2) {
1033             // We should never equate two type variables, both of which have
1034             // known types. Instead, we recursively equate those types.
1035             (TypeVarValue::Known(t1), TypeVarValue::Known(t2)) => panic!(
1036                 "equating two type variables, both of which have known types: {:?} and {:?}",
1037                 t1, t2
1038             ),
1039
1040             // If one side is known, prefer that one.
1041             (TypeVarValue::Known(..), TypeVarValue::Unknown) => Ok(value1.clone()),
1042             (TypeVarValue::Unknown, TypeVarValue::Known(..)) => Ok(value2.clone()),
1043
1044             (TypeVarValue::Unknown, TypeVarValue::Unknown) => Ok(TypeVarValue::Unknown),
1045         }
1046     }
1047 }
1048
1049 /// The kinds of placeholders we need during type inference. There's separate
1050 /// values for general types, and for integer and float variables. The latter
1051 /// two are used for inference of literal values (e.g. `100` could be one of
1052 /// several integer types).
1053 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
1054 pub enum InferTy {
1055     TypeVar(TypeVarId),
1056     IntVar(TypeVarId),
1057     FloatVar(TypeVarId),
1058 }
1059
1060 impl InferTy {
1061     fn to_inner(self) -> TypeVarId {
1062         match self {
1063             InferTy::TypeVar(ty) | InferTy::IntVar(ty) | InferTy::FloatVar(ty) => ty,
1064         }
1065     }
1066
1067     fn fallback_value(self) -> Ty {
1068         match self {
1069             InferTy::TypeVar(..) => Ty::Unknown,
1070             InferTy::IntVar(..) => {
1071                 Ty::Int(primitive::UncertainIntTy::Signed(primitive::IntTy::I32))
1072             }
1073             InferTy::FloatVar(..) => {
1074                 Ty::Float(primitive::UncertainFloatTy::Known(primitive::FloatTy::F64))
1075             }
1076         }
1077     }
1078 }
1079
1080 /// When inferring an expression, we propagate downward whatever type hint we
1081 /// are able in the form of an `Expectation`.
1082 #[derive(Clone, PartialEq, Eq, Debug)]
1083 struct Expectation {
1084     ty: Ty,
1085     // TODO: In some cases, we need to be aware whether the expectation is that
1086     // the type match exactly what we passed, or whether it just needs to be
1087     // coercible to the expected type. See Expectation::rvalue_hint in rustc.
1088 }
1089
1090 impl Expectation {
1091     /// The expectation that the type of the expression needs to equal the given
1092     /// type.
1093     fn has_type(ty: Ty) -> Self {
1094         Expectation { ty }
1095     }
1096
1097     /// This expresses no expectation on the type.
1098     fn none() -> Self {
1099         Expectation { ty: Ty::Unknown }
1100     }
1101 }