git.lizzy.rs Git - rust.git/blob - crates/hir_ty/src/infer/pat.rs
Implement box pattern inference
[rust.git] / crates / hir_ty / src / infer / pat.rs
1 //! Type inference for patterns.
2
3 use std::iter::repeat;
4 use std::sync::Arc;
5
6 use hir_def::{
7     expr::{BindingAnnotation, Expr, Literal, Pat, PatId, RecordFieldPat},
8     path::Path,
9     type_ref::Mutability,
10     FieldId,
11 };
12 use hir_expand::name::Name;
13 use test_utils::mark;
14
15 use super::{BindingMode, Expectation, InferenceContext};
16 use crate::{utils::variant_data, Substs, Ty, TypeCtor};
17
18 impl<'a> InferenceContext<'a> {
19     fn infer_tuple_struct_pat(
20         &mut self,
21         path: Option<&Path>,
22         subpats: &[PatId],
23         expected: &Ty,
24         default_bm: BindingMode,
25         id: PatId,
26     ) -> Ty {
27         let (ty, def) = self.resolve_variant(path);
28         let var_data = def.map(|it| variant_data(self.db.upcast(), it));
29         if let Some(variant) = def {
30             self.write_variant_resolution(id.into(), variant);
31         }
32         self.unify(&ty, expected);
33
34         let substs = ty.substs().unwrap_or_else(Substs::empty);
35
36         let field_tys = def.map(|it| self.db.field_types(it)).unwrap_or_default();
37
38         for (i, &subpat) in subpats.iter().enumerate() {
39             let expected_ty = var_data
40                 .as_ref()
41                 .and_then(|d| d.field(&Name::new_tuple_field(i)))
42                 .map_or(Ty::Unknown, |field| field_tys[field].clone().subst(&substs));
43             let expected_ty = self.normalize_associated_types_in(expected_ty);
44             self.infer_pat(subpat, &expected_ty, default_bm);
45         }
46
47         ty
48     }
49
50     fn infer_record_pat(
51         &mut self,
52         path: Option<&Path>,
53         subpats: &[RecordFieldPat],
54         expected: &Ty,
55         default_bm: BindingMode,
56         id: PatId,
57     ) -> Ty {
58         let (ty, def) = self.resolve_variant(path);
59         let var_data = def.map(|it| variant_data(self.db.upcast(), it));
60         if let Some(variant) = def {
61             self.write_variant_resolution(id.into(), variant);
62         }
63
64         self.unify(&ty, expected);
65
66         let substs = ty.substs().unwrap_or_else(Substs::empty);
67
68         let field_tys = def.map(|it| self.db.field_types(it)).unwrap_or_default();
69         for subpat in subpats {
70             let matching_field = var_data.as_ref().and_then(|it| it.field(&subpat.name));
71             if let Some(local_id) = matching_field {
72                 let field_def = FieldId { parent: def.unwrap(), local_id };
73                 self.result.record_pat_field_resolutions.insert(subpat.pat, field_def);
74             }
75
76             let expected_ty =
77                 matching_field.map_or(Ty::Unknown, |field| field_tys[field].clone().subst(&substs));
78             let expected_ty = self.normalize_associated_types_in(expected_ty);
79             self.infer_pat(subpat.pat, &expected_ty, default_bm);
80         }
81
82         ty
83     }
84
85     pub(super) fn infer_pat(
86         &mut self,
87         pat: PatId,
88         mut expected: &Ty,
89         mut default_bm: BindingMode,
90     ) -> Ty {
91         let body = Arc::clone(&self.body); // avoid borrow checker problem
92
93         if is_non_ref_pat(&body, pat) {
94             while let Some((inner, mutability)) = expected.as_reference() {
95                 expected = inner;
96                 default_bm = match default_bm {
97                     BindingMode::Move => BindingMode::Ref(mutability),
98                     BindingMode::Ref(Mutability::Shared) => BindingMode::Ref(Mutability::Shared),
99                     BindingMode::Ref(Mutability::Mut) => BindingMode::Ref(mutability),
100                 }
101             }
102         } else if let Pat::Ref { .. } = &body[pat] {
103             mark::hit!(match_ergonomics_ref);
104             // When you encounter a `&pat` pattern, reset to Move.
105             // This is so that `w` is by value: `let (_, &w) = &(1, &2);`
106             default_bm = BindingMode::Move;
107         }
108
109         // Lose mutability.
110         let default_bm = default_bm;
111         let expected = expected;
112
113         let ty = match &body[pat] {
114             Pat::Tuple { ref args, .. } => {
115                 let expectations = match expected.as_tuple() {
116                     Some(parameters) => &*parameters.0,
117                     _ => &[],
118                 };
119                 let expectations_iter = expectations.iter().chain(repeat(&Ty::Unknown));
120
121                 let inner_tys = args
122                     .iter()
123                     .zip(expectations_iter)
124                     .map(|(&pat, ty)| self.infer_pat(pat, ty, default_bm))
125                     .collect();
126
127                 Ty::apply(TypeCtor::Tuple { cardinality: args.len() as u16 }, Substs(inner_tys))
128             }
129             Pat::Or(ref pats) => {
130                 if let Some((first_pat, rest)) = pats.split_first() {
131                     let ty = self.infer_pat(*first_pat, expected, default_bm);
132                     for pat in rest {
133                         self.infer_pat(*pat, expected, default_bm);
134                     }
135                     ty
136                 } else {
137                     Ty::Unknown
138                 }
139             }
140             Pat::Ref { pat, mutability } => {
141                 let expectation = match expected.as_reference() {
142                     Some((inner_ty, exp_mut)) => {
143                         if *mutability != exp_mut {
144                             // FIXME: emit type error?
145                         }
146                         inner_ty
147                     }
148                     _ => &Ty::Unknown,
149                 };
150                 let subty = self.infer_pat(*pat, expectation, default_bm);
151                 Ty::apply_one(TypeCtor::Ref(*mutability), subty)
152             }
153             Pat::TupleStruct { path: p, args: subpats, .. } => {
154                 self.infer_tuple_struct_pat(p.as_ref(), subpats, expected, default_bm, pat)
155             }
156             Pat::Record { path: p, args: fields, ellipsis: _ } => {
157                 self.infer_record_pat(p.as_ref(), fields, expected, default_bm, pat)
158             }
159             Pat::Path(path) => {
160                 // FIXME use correct resolver for the surrounding expression
161                 let resolver = self.resolver.clone();
162                 self.infer_path(&resolver, &path, pat.into()).unwrap_or(Ty::Unknown)
163             }
164             Pat::Bind { mode, name: _, subpat } => {
165                 let mode = if mode == &BindingAnnotation::Unannotated {
166                     default_bm
167                 } else {
168                     BindingMode::convert(*mode)
169                 };
170                 let inner_ty = if let Some(subpat) = subpat {
171                     self.infer_pat(*subpat, expected, default_bm)
172                 } else {
173                     expected.clone()
174                 };
175                 let inner_ty = self.insert_type_vars_shallow(inner_ty);
176
177                 let bound_ty = match mode {
178                     BindingMode::Ref(mutability) => {
179                         Ty::apply_one(TypeCtor::Ref(mutability), inner_ty.clone())
180                     }
181                     BindingMode::Move => inner_ty.clone(),
182                 };
183                 let bound_ty = self.resolve_ty_as_possible(bound_ty);
184                 self.write_pat_ty(pat, bound_ty);
185                 return inner_ty;
186             }
187             Pat::Slice { prefix, slice, suffix } => {
188                 let (container_ty, elem_ty) = match &expected {
189                     ty_app!(TypeCtor::Array, st) => (TypeCtor::Array, st.as_single().clone()),
190                     ty_app!(TypeCtor::Slice, st) => (TypeCtor::Slice, st.as_single().clone()),
191                     _ => (TypeCtor::Slice, Ty::Unknown),
192                 };
193
194                 for pat_id in prefix.iter().chain(suffix) {
195                     self.infer_pat(*pat_id, &elem_ty, default_bm);
196                 }
197
198                 let pat_ty = Ty::apply_one(container_ty, elem_ty);
199                 if let Some(slice_pat_id) = slice {
200                     self.infer_pat(*slice_pat_id, &pat_ty, default_bm);
201                 }
202
203                 pat_ty
204             }
205             Pat::Wild => expected.clone(),
206             Pat::Range { start, end } => {
207                 let start_ty = self.infer_expr(*start, &Expectation::has_type(expected.clone()));
208                 let end_ty = self.infer_expr(*end, &Expectation::has_type(start_ty));
209                 end_ty
210             }
211             Pat::Lit(expr) => self.infer_expr(*expr, &Expectation::has_type(expected.clone())),
212             Pat::Box { inner } => match self.resolve_boxed_box() {
213                 Some(box_adt) => {
214                     let inner_expected = match expected.as_adt() {
215                         Some((adt, substs)) if adt == box_adt => substs.as_single(),
216                         _ => &Ty::Unknown,
217                     };
218
219                     let inner_ty = self.infer_pat(*inner, inner_expected, default_bm);
220                     Ty::apply_one(TypeCtor::Adt(box_adt), inner_ty)
221                 }
222                 None => Ty::Unknown,
223             },
224             Pat::Missing => Ty::Unknown,
225         };
226         // use a new type variable if we got Ty::Unknown here
227         let ty = self.insert_type_vars_shallow(ty);
228         if !self.unify(&ty, expected) {
229             // FIXME record mismatch, we need to change the type of self.type_mismatches for that
230         }
231         let ty = self.resolve_ty_as_possible(ty);
232         self.write_pat_ty(pat, ty.clone());
233         ty
234     }
235 }
236
237 fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool {
238     match &body[pat] {
239         Pat::Tuple { .. }
240         | Pat::TupleStruct { .. }
241         | Pat::Record { .. }
242         | Pat::Range { .. }
243         | Pat::Slice { .. } => true,
244         Pat::Or(pats) => pats.iter().all(|p| is_non_ref_pat(body, *p)),
245         // FIXME: Path/Lit might actually evaluate to ref, but inference is unimplemented.
246         Pat::Path(..) => true,
247         Pat::Lit(expr) => match body[*expr] {
248             Expr::Literal(Literal::String(..)) => false,
249             _ => true,
250         },
251         Pat::Wild | Pat::Bind { .. } | Pat::Ref { .. } | Pat::Box { .. } | Pat::Missing => false,
252     }
253 }