]> git.lizzy.rs Git - rust.git/blob - crates/hir_ty/src/utils.rs
Avoid type inference panic on bitslice methods
[rust.git] / crates / hir_ty / src / utils.rs
1 //! Helper functions for working with def, which don't need to be a separate
2 //! query, but can't be computed directly from `*Data` (ie, which need a `db`).
3
4 use std::{array, iter};
5
6 use base_db::CrateId;
7 use chalk_ir::{fold::Shift, BoundVar, DebruijnIndex};
8 use hir_def::{
9     db::DefDatabase,
10     generics::{
11         GenericParams, TypeParamData, TypeParamProvenance, WherePredicate, WherePredicateTypeTarget,
12     },
13     intern::Interned,
14     path::Path,
15     resolver::{HasResolver, TypeNs},
16     type_ref::{TraitBoundModifier, TypeRef},
17     AssocContainerId, GenericDefId, Lookup, TraitId, TypeAliasId, TypeParamId,
18 };
19 use hir_expand::name::{name, Name};
20 use rustc_hash::FxHashSet;
21
22 use crate::{
23     db::HirDatabase, ChalkTraitId, Interner, Substitution, TraitRef, TraitRefExt, TyKind,
24     WhereClause,
25 };
26
27 pub(crate) fn fn_traits(db: &dyn DefDatabase, krate: CrateId) -> impl Iterator<Item = TraitId> {
28     let fn_traits = [
29         db.lang_item(krate, "fn".into()),
30         db.lang_item(krate, "fn_mut".into()),
31         db.lang_item(krate, "fn_once".into()),
32     ];
33     array::IntoIter::new(fn_traits).into_iter().flatten().flat_map(|it| it.as_trait())
34 }
35
36 fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId> {
37     let resolver = trait_.resolver(db);
38     // returning the iterator directly doesn't easily work because of
39     // lifetime problems, but since there usually shouldn't be more than a
40     // few direct traits this should be fine (we could even use some kind of
41     // SmallVec if performance is a concern)
42     let generic_params = db.generic_params(trait_.into());
43     let trait_self = generic_params.find_trait_self_param();
44     generic_params
45         .where_predicates
46         .iter()
47         .filter_map(|pred| match pred {
48             WherePredicate::ForLifetime { target, bound, .. }
49             | WherePredicate::TypeBound { target, bound } => match target {
50                 WherePredicateTypeTarget::TypeRef(type_ref) => match &**type_ref {
51                     TypeRef::Path(p) if p == &Path::from(name![Self]) => bound.as_path(),
52                     _ => None,
53                 },
54                 WherePredicateTypeTarget::TypeParam(local_id) if Some(*local_id) == trait_self => {
55                     bound.as_path()
56                 }
57                 _ => None,
58             },
59             WherePredicate::Lifetime { .. } => None,
60         })
61         .filter_map(|(path, bound_modifier)| match bound_modifier {
62             TraitBoundModifier::None => Some(path),
63             TraitBoundModifier::Maybe => None,
64         })
65         .filter_map(|path| match resolver.resolve_path_in_type_ns_fully(db, path.mod_path()) {
66             Some(TypeNs::TraitId(t)) => Some(t),
67             _ => None,
68         })
69         .collect()
70 }
71
72 fn direct_super_trait_refs(db: &dyn HirDatabase, trait_ref: &TraitRef) -> Vec<TraitRef> {
73     // returning the iterator directly doesn't easily work because of
74     // lifetime problems, but since there usually shouldn't be more than a
75     // few direct traits this should be fine (we could even use some kind of
76     // SmallVec if performance is a concern)
77     let generic_params = db.generic_params(trait_ref.hir_trait_id().into());
78     let trait_self = match generic_params.find_trait_self_param() {
79         Some(p) => TypeParamId { parent: trait_ref.hir_trait_id().into(), local_id: p },
80         None => return Vec::new(),
81     };
82     db.generic_predicates_for_param(trait_self)
83         .iter()
84         .filter_map(|pred| {
85             pred.as_ref().filter_map(|pred| match pred.skip_binders() {
86                 // FIXME: how to correctly handle higher-ranked bounds here?
87                 WhereClause::Implemented(tr) => Some(
88                     tr.clone()
89                         .shifted_out_to(&Interner, DebruijnIndex::ONE)
90                         .expect("FIXME unexpected higher-ranked trait bound"),
91                 ),
92                 _ => None,
93             })
94         })
95         .map(|pred| pred.substitute(&Interner, &trait_ref.substitution))
96         .collect()
97 }
98
99 /// Returns an iterator over the whole super trait hierarchy (including the
100 /// trait itself).
101 pub fn all_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId> {
102     // we need to take care a bit here to avoid infinite loops in case of cycles
103     // (i.e. if we have `trait A: B; trait B: A;`)
104     let mut result = vec![trait_];
105     let mut i = 0;
106     while i < result.len() {
107         let t = result[i];
108         // yeah this is quadratic, but trait hierarchies should be flat
109         // enough that this doesn't matter
110         for tt in direct_super_traits(db, t) {
111             if !result.contains(&tt) {
112                 result.push(tt);
113             }
114         }
115         i += 1;
116     }
117     result
118 }
119
120 /// Given a trait ref (`Self: Trait`), builds all the implied trait refs for
121 /// super traits. The original trait ref will be included. So the difference to
122 /// `all_super_traits` is that we keep track of type parameters; for example if
123 /// we have `Self: Trait<u32, i32>` and `Trait<T, U>: OtherTrait<U>` we'll get
124 /// `Self: OtherTrait<i32>`.
125 pub(super) fn all_super_trait_refs(db: &dyn HirDatabase, trait_ref: TraitRef) -> SuperTraits {
126     SuperTraits { db, seen: iter::once(trait_ref.trait_id).collect(), stack: vec![trait_ref] }
127 }
128
129 pub(super) struct SuperTraits<'a> {
130     db: &'a dyn HirDatabase,
131     stack: Vec<TraitRef>,
132     seen: FxHashSet<ChalkTraitId>,
133 }
134
135 impl<'a> SuperTraits<'a> {
136     fn elaborate(&mut self, trait_ref: &TraitRef) {
137         let mut trait_refs = direct_super_trait_refs(self.db, trait_ref);
138         trait_refs.retain(|tr| !self.seen.contains(&tr.trait_id));
139         self.stack.extend(trait_refs);
140     }
141 }
142
143 impl<'a> Iterator for SuperTraits<'a> {
144     type Item = TraitRef;
145
146     fn next(&mut self) -> Option<Self::Item> {
147         if let Some(next) = self.stack.pop() {
148             self.elaborate(&next);
149             Some(next)
150         } else {
151             None
152         }
153     }
154 }
155
156 pub(super) fn associated_type_by_name_including_super_traits(
157     db: &dyn HirDatabase,
158     trait_ref: TraitRef,
159     name: &Name,
160 ) -> Option<(TraitRef, TypeAliasId)> {
161     all_super_trait_refs(db, trait_ref).find_map(|t| {
162         let assoc_type = db.trait_data(t.hir_trait_id()).associated_type_by_name(name)?;
163         Some((t, assoc_type))
164     })
165 }
166
167 pub(crate) fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics {
168     let parent_generics = parent_generic_def(db, def).map(|def| Box::new(generics(db, def)));
169     Generics { def, params: db.generic_params(def), parent_generics }
170 }
171
172 #[derive(Debug)]
173 pub(crate) struct Generics {
174     def: GenericDefId,
175     pub(crate) params: Interned<GenericParams>,
176     parent_generics: Option<Box<Generics>>,
177 }
178
179 impl Generics {
180     pub(crate) fn iter<'a>(
181         &'a self,
182     ) -> impl Iterator<Item = (TypeParamId, &'a TypeParamData)> + 'a {
183         self.parent_generics
184             .as_ref()
185             .into_iter()
186             .flat_map(|it| {
187                 it.params
188                     .types
189                     .iter()
190                     .map(move |(local_id, p)| (TypeParamId { parent: it.def, local_id }, p))
191             })
192             .chain(
193                 self.params
194                     .types
195                     .iter()
196                     .map(move |(local_id, p)| (TypeParamId { parent: self.def, local_id }, p)),
197             )
198     }
199
200     pub(crate) fn iter_parent<'a>(
201         &'a self,
202     ) -> impl Iterator<Item = (TypeParamId, &'a TypeParamData)> + 'a {
203         self.parent_generics.as_ref().into_iter().flat_map(|it| {
204             it.params
205                 .types
206                 .iter()
207                 .map(move |(local_id, p)| (TypeParamId { parent: it.def, local_id }, p))
208         })
209     }
210
211     pub(crate) fn len(&self) -> usize {
212         self.len_split().0
213     }
214
215     /// (total, parents, child)
216     pub(crate) fn len_split(&self) -> (usize, usize, usize) {
217         let parent = self.parent_generics.as_ref().map_or(0, |p| p.len());
218         let child = self.params.types.len();
219         (parent + child, parent, child)
220     }
221
222     /// (parent total, self param, type param list, impl trait)
223     pub(crate) fn provenance_split(&self) -> (usize, usize, usize, usize) {
224         let parent = self.parent_generics.as_ref().map_or(0, |p| p.len());
225         let self_params = self
226             .params
227             .types
228             .iter()
229             .filter(|(_, p)| p.provenance == TypeParamProvenance::TraitSelf)
230             .count();
231         let list_params = self
232             .params
233             .types
234             .iter()
235             .filter(|(_, p)| p.provenance == TypeParamProvenance::TypeParamList)
236             .count();
237         let impl_trait_params = self
238             .params
239             .types
240             .iter()
241             .filter(|(_, p)| p.provenance == TypeParamProvenance::ArgumentImplTrait)
242             .count();
243         (parent, self_params, list_params, impl_trait_params)
244     }
245
246     pub(crate) fn param_idx(&self, param: TypeParamId) -> Option<usize> {
247         Some(self.find_param(param)?.0)
248     }
249
250     fn find_param(&self, param: TypeParamId) -> Option<(usize, &TypeParamData)> {
251         if param.parent == self.def {
252             let (idx, (_local_id, data)) = self
253                 .params
254                 .types
255                 .iter()
256                 .enumerate()
257                 .find(|(_, (idx, _))| *idx == param.local_id)
258                 .unwrap();
259             let (_total, parent_len, _child) = self.len_split();
260             Some((parent_len + idx, data))
261         } else {
262             self.parent_generics.as_ref().and_then(|g| g.find_param(param))
263         }
264     }
265
266     /// Returns a Substitution that replaces each parameter by a bound variable.
267     pub(crate) fn bound_vars_subst(&self, debruijn: DebruijnIndex) -> Substitution {
268         Substitution::from_iter(
269             &Interner,
270             self.iter()
271                 .enumerate()
272                 .map(|(idx, _)| TyKind::BoundVar(BoundVar::new(debruijn, idx)).intern(&Interner)),
273         )
274     }
275
276     /// Returns a Substitution that replaces each parameter by itself (i.e. `Ty::Param`).
277     pub(crate) fn type_params_subst(&self, db: &dyn HirDatabase) -> Substitution {
278         Substitution::from_iter(
279             &Interner,
280             self.iter().map(|(id, _)| {
281                 TyKind::Placeholder(crate::to_placeholder_idx(db, id)).intern(&Interner)
282             }),
283         )
284     }
285 }
286
287 fn parent_generic_def(db: &dyn DefDatabase, def: GenericDefId) -> Option<GenericDefId> {
288     let container = match def {
289         GenericDefId::FunctionId(it) => it.lookup(db).container,
290         GenericDefId::TypeAliasId(it) => it.lookup(db).container,
291         GenericDefId::ConstId(it) => it.lookup(db).container,
292         GenericDefId::EnumVariantId(it) => return Some(it.parent.into()),
293         GenericDefId::AdtId(_) | GenericDefId::TraitId(_) | GenericDefId::ImplId(_) => return None,
294     };
295
296     match container {
297         AssocContainerId::ImplId(it) => Some(it.into()),
298         AssocContainerId::TraitId(it) => Some(it.into()),
299         AssocContainerId::ModuleId(_) => None,
300     }
301 }