]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_trait_selection/src/traits/util.rs
drive-by: use is_const and is_const_if_const
[rust.git] / compiler / rustc_trait_selection / src / traits / util.rs
1 use rustc_errors::DiagnosticBuilder;
2 use rustc_span::Span;
3 use smallvec::smallvec;
4 use smallvec::SmallVec;
5
6 use rustc_data_structures::fx::FxHashSet;
7 use rustc_hir::def_id::DefId;
8 use rustc_middle::ty::subst::{GenericArg, Subst, SubstsRef};
9 use rustc_middle::ty::{self, ToPredicate, Ty, TyCtxt, TypeFoldable};
10
11 use super::{Normalized, Obligation, ObligationCause, PredicateObligation, SelectionContext};
12 pub use rustc_infer::traits::{self, util::*};
13
14 use std::iter;
15
16 ///////////////////////////////////////////////////////////////////////////
17 // `TraitAliasExpander` iterator
18 ///////////////////////////////////////////////////////////////////////////
19
20 /// "Trait alias expansion" is the process of expanding a sequence of trait
21 /// references into another sequence by transitively following all trait
22 /// aliases. e.g. If you have bounds like `Foo + Send`, a trait alias
23 /// `trait Foo = Bar + Sync;`, and another trait alias
24 /// `trait Bar = Read + Write`, then the bounds would expand to
25 /// `Read + Write + Sync + Send`.
26 /// Expansion is done via a DFS (depth-first search), and the `visited` field
27 /// is used to avoid cycles.
28 pub struct TraitAliasExpander<'tcx> {
29     tcx: TyCtxt<'tcx>,
30     stack: Vec<TraitAliasExpansionInfo<'tcx>>,
31 }
32
33 /// Stores information about the expansion of a trait via a path of zero or more trait aliases.
34 #[derive(Debug, Clone)]
35 pub struct TraitAliasExpansionInfo<'tcx> {
36     pub path: SmallVec<[(ty::PolyTraitRef<'tcx>, Span); 4]>,
37 }
38
39 impl<'tcx> TraitAliasExpansionInfo<'tcx> {
40     fn new(trait_ref: ty::PolyTraitRef<'tcx>, span: Span) -> Self {
41         Self { path: smallvec![(trait_ref, span)] }
42     }
43
44     /// Adds diagnostic labels to `diag` for the expansion path of a trait through all intermediate
45     /// trait aliases.
46     pub fn label_with_exp_info(
47         &self,
48         diag: &mut DiagnosticBuilder<'_>,
49         top_label: &str,
50         use_desc: &str,
51     ) {
52         diag.span_label(self.top().1, top_label);
53         if self.path.len() > 1 {
54             for (_, sp) in self.path.iter().rev().skip(1).take(self.path.len() - 2) {
55                 diag.span_label(*sp, format!("referenced here ({})", use_desc));
56             }
57         }
58         if self.top().1 != self.bottom().1 {
59             // When the trait object is in a return type these two spans match, we don't want
60             // redundant labels.
61             diag.span_label(
62                 self.bottom().1,
63                 format!("trait alias used in trait object type ({})", use_desc),
64             );
65         }
66     }
67
68     pub fn trait_ref(&self) -> ty::PolyTraitRef<'tcx> {
69         self.top().0
70     }
71
72     pub fn top(&self) -> &(ty::PolyTraitRef<'tcx>, Span) {
73         self.path.last().unwrap()
74     }
75
76     pub fn bottom(&self) -> &(ty::PolyTraitRef<'tcx>, Span) {
77         self.path.first().unwrap()
78     }
79
80     fn clone_and_push(&self, trait_ref: ty::PolyTraitRef<'tcx>, span: Span) -> Self {
81         let mut path = self.path.clone();
82         path.push((trait_ref, span));
83
84         Self { path }
85     }
86 }
87
88 pub fn expand_trait_aliases<'tcx>(
89     tcx: TyCtxt<'tcx>,
90     trait_refs: impl Iterator<Item = (ty::PolyTraitRef<'tcx>, Span)>,
91 ) -> TraitAliasExpander<'tcx> {
92     let items: Vec<_> =
93         trait_refs.map(|(trait_ref, span)| TraitAliasExpansionInfo::new(trait_ref, span)).collect();
94     TraitAliasExpander { tcx, stack: items }
95 }
96
97 impl<'tcx> TraitAliasExpander<'tcx> {
98     /// If `item` is a trait alias and its predicate has not yet been visited, then expands `item`
99     /// to the definition, pushes the resulting expansion onto `self.stack`, and returns `false`.
100     /// Otherwise, immediately returns `true` if `item` is a regular trait, or `false` if it is a
101     /// trait alias.
102     /// The return value indicates whether `item` should be yielded to the user.
103     fn expand(&mut self, item: &TraitAliasExpansionInfo<'tcx>) -> bool {
104         let tcx = self.tcx;
105         let trait_ref = item.trait_ref();
106         let pred = trait_ref.without_const().to_predicate(tcx);
107
108         debug!("expand_trait_aliases: trait_ref={:?}", trait_ref);
109
110         // Don't recurse if this bound is not a trait alias.
111         let is_alias = tcx.is_trait_alias(trait_ref.def_id());
112         if !is_alias {
113             return true;
114         }
115
116         // Don't recurse if this trait alias is already on the stack for the DFS search.
117         let anon_pred = anonymize_predicate(tcx, pred);
118         if item.path.iter().rev().skip(1).any(|&(tr, _)| {
119             anonymize_predicate(tcx, tr.without_const().to_predicate(tcx)) == anon_pred
120         }) {
121             return false;
122         }
123
124         // Get components of trait alias.
125         let predicates = tcx.super_predicates_of(trait_ref.def_id());
126
127         let items = predicates.predicates.iter().rev().filter_map(|(pred, span)| {
128             pred.subst_supertrait(tcx, &trait_ref)
129                 .to_opt_poly_trait_pred()
130                 .map(|trait_ref| item.clone_and_push(trait_ref.map_bound(|t| t.trait_ref), *span))
131         });
132         debug!("expand_trait_aliases: items={:?}", items.clone());
133
134         self.stack.extend(items);
135
136         false
137     }
138 }
139
140 impl<'tcx> Iterator for TraitAliasExpander<'tcx> {
141     type Item = TraitAliasExpansionInfo<'tcx>;
142
143     fn size_hint(&self) -> (usize, Option<usize>) {
144         (self.stack.len(), None)
145     }
146
147     fn next(&mut self) -> Option<TraitAliasExpansionInfo<'tcx>> {
148         while let Some(item) = self.stack.pop() {
149             if self.expand(&item) {
150                 return Some(item);
151             }
152         }
153         None
154     }
155 }
156
157 ///////////////////////////////////////////////////////////////////////////
158 // Iterator over def-IDs of supertraits
159 ///////////////////////////////////////////////////////////////////////////
160
161 pub struct SupertraitDefIds<'tcx> {
162     tcx: TyCtxt<'tcx>,
163     stack: Vec<DefId>,
164     visited: FxHashSet<DefId>,
165 }
166
167 pub fn supertrait_def_ids(tcx: TyCtxt<'_>, trait_def_id: DefId) -> SupertraitDefIds<'_> {
168     SupertraitDefIds {
169         tcx,
170         stack: vec![trait_def_id],
171         visited: Some(trait_def_id).into_iter().collect(),
172     }
173 }
174
175 impl Iterator for SupertraitDefIds<'_> {
176     type Item = DefId;
177
178     fn next(&mut self) -> Option<DefId> {
179         let def_id = self.stack.pop()?;
180         let predicates = self.tcx.super_predicates_of(def_id);
181         let visited = &mut self.visited;
182         self.stack.extend(
183             predicates
184                 .predicates
185                 .iter()
186                 .filter_map(|(pred, _)| pred.to_opt_poly_trait_pred())
187                 .map(|trait_ref| trait_ref.def_id())
188                 .filter(|&super_def_id| visited.insert(super_def_id)),
189         );
190         Some(def_id)
191     }
192 }
193
194 ///////////////////////////////////////////////////////////////////////////
195 // Other
196 ///////////////////////////////////////////////////////////////////////////
197
198 /// Instantiate all bound parameters of the impl with the given substs,
199 /// returning the resulting trait ref and all obligations that arise.
200 /// The obligations are closed under normalization.
201 pub fn impl_trait_ref_and_oblig<'a, 'tcx>(
202     selcx: &mut SelectionContext<'a, 'tcx>,
203     param_env: ty::ParamEnv<'tcx>,
204     impl_def_id: DefId,
205     impl_substs: SubstsRef<'tcx>,
206 ) -> (ty::TraitRef<'tcx>, impl Iterator<Item = PredicateObligation<'tcx>>) {
207     let impl_trait_ref = selcx.tcx().impl_trait_ref(impl_def_id).unwrap();
208     let impl_trait_ref = impl_trait_ref.subst(selcx.tcx(), impl_substs);
209     let Normalized { value: impl_trait_ref, obligations: normalization_obligations1 } =
210         super::normalize(selcx, param_env, ObligationCause::dummy(), impl_trait_ref);
211
212     let predicates = selcx.tcx().predicates_of(impl_def_id);
213     let predicates = predicates.instantiate(selcx.tcx(), impl_substs);
214     let Normalized { value: predicates, obligations: normalization_obligations2 } =
215         super::normalize(selcx, param_env, ObligationCause::dummy(), predicates);
216     let impl_obligations =
217         predicates_for_generics(ObligationCause::dummy(), 0, param_env, predicates);
218
219     let impl_obligations = impl_obligations
220         .chain(normalization_obligations1.into_iter())
221         .chain(normalization_obligations2.into_iter());
222
223     (impl_trait_ref, impl_obligations)
224 }
225
226 pub fn predicates_for_generics<'tcx>(
227     cause: ObligationCause<'tcx>,
228     recursion_depth: usize,
229     param_env: ty::ParamEnv<'tcx>,
230     generic_bounds: ty::InstantiatedPredicates<'tcx>,
231 ) -> impl Iterator<Item = PredicateObligation<'tcx>> {
232     debug!("predicates_for_generics(generic_bounds={:?})", generic_bounds);
233
234     iter::zip(generic_bounds.predicates, generic_bounds.spans).map(move |(predicate, span)| {
235         let cause = match *cause.code() {
236             traits::ItemObligation(def_id) if !span.is_dummy() => traits::ObligationCause::new(
237                 cause.span,
238                 cause.body_id,
239                 traits::BindingObligation(def_id, span),
240             ),
241             _ => cause.clone(),
242         };
243         Obligation { cause, recursion_depth, param_env, predicate }
244     })
245 }
246
247 pub fn predicate_for_trait_ref<'tcx>(
248     tcx: TyCtxt<'tcx>,
249     cause: ObligationCause<'tcx>,
250     param_env: ty::ParamEnv<'tcx>,
251     trait_ref: ty::TraitRef<'tcx>,
252     recursion_depth: usize,
253 ) -> PredicateObligation<'tcx> {
254     Obligation {
255         cause,
256         param_env,
257         recursion_depth,
258         predicate: ty::Binder::dummy(trait_ref).without_const().to_predicate(tcx),
259     }
260 }
261
262 pub fn predicate_for_trait_def<'tcx>(
263     tcx: TyCtxt<'tcx>,
264     param_env: ty::ParamEnv<'tcx>,
265     cause: ObligationCause<'tcx>,
266     trait_def_id: DefId,
267     recursion_depth: usize,
268     self_ty: Ty<'tcx>,
269     params: &[GenericArg<'tcx>],
270 ) -> PredicateObligation<'tcx> {
271     let trait_ref =
272         ty::TraitRef { def_id: trait_def_id, substs: tcx.mk_substs_trait(self_ty, params) };
273     predicate_for_trait_ref(tcx, cause, param_env, trait_ref, recursion_depth)
274 }
275
276 /// Casts a trait reference into a reference to one of its super
277 /// traits; returns `None` if `target_trait_def_id` is not a
278 /// supertrait.
279 pub fn upcast_choices<'tcx>(
280     tcx: TyCtxt<'tcx>,
281     source_trait_ref: ty::PolyTraitRef<'tcx>,
282     target_trait_def_id: DefId,
283 ) -> Vec<ty::PolyTraitRef<'tcx>> {
284     if source_trait_ref.def_id() == target_trait_def_id {
285         return vec![source_trait_ref]; // Shortcut the most common case.
286     }
287
288     supertraits(tcx, source_trait_ref).filter(|r| r.def_id() == target_trait_def_id).collect()
289 }
290
291 /// Given a trait `trait_ref`, returns the number of vtable entries
292 /// that come from `trait_ref`, excluding its supertraits. Used in
293 /// computing the vtable base for an upcast trait of a trait object.
294 pub fn count_own_vtable_entries<'tcx>(
295     tcx: TyCtxt<'tcx>,
296     trait_ref: ty::PolyTraitRef<'tcx>,
297 ) -> usize {
298     let existential_trait_ref =
299         trait_ref.map_bound(|trait_ref| ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref));
300     let existential_trait_ref = tcx.erase_regions(existential_trait_ref);
301     tcx.own_existential_vtable_entries(existential_trait_ref).len()
302 }
303
304 /// Given an upcast trait object described by `object`, returns the
305 /// index of the method `method_def_id` (which should be part of
306 /// `object.upcast_trait_ref`) within the vtable for `object`.
307 pub fn get_vtable_index_of_object_method<'tcx, N>(
308     tcx: TyCtxt<'tcx>,
309     object: &super::ImplSourceObjectData<'tcx, N>,
310     method_def_id: DefId,
311 ) -> usize {
312     let existential_trait_ref = object
313         .upcast_trait_ref
314         .map_bound(|trait_ref| ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref));
315     let existential_trait_ref = tcx.erase_regions(existential_trait_ref);
316     // Count number of methods preceding the one we are selecting and
317     // add them to the total offset.
318     let index = tcx
319         .own_existential_vtable_entries(existential_trait_ref)
320         .iter()
321         .copied()
322         .position(|def_id| def_id == method_def_id)
323         .unwrap_or_else(|| {
324             bug!("get_vtable_index_of_object_method: {:?} was not found", method_def_id);
325         });
326     object.vtable_base + index
327 }
328
329 pub fn closure_trait_ref_and_return_type<'tcx>(
330     tcx: TyCtxt<'tcx>,
331     fn_trait_def_id: DefId,
332     self_ty: Ty<'tcx>,
333     sig: ty::PolyFnSig<'tcx>,
334     tuple_arguments: TupleArgumentsFlag,
335 ) -> ty::Binder<'tcx, (ty::TraitRef<'tcx>, Ty<'tcx>)> {
336     let arguments_tuple = match tuple_arguments {
337         TupleArgumentsFlag::No => sig.skip_binder().inputs()[0],
338         TupleArgumentsFlag::Yes => tcx.intern_tup(sig.skip_binder().inputs()),
339     };
340     debug_assert!(!self_ty.has_escaping_bound_vars());
341     let trait_ref = ty::TraitRef {
342         def_id: fn_trait_def_id,
343         substs: tcx.mk_substs_trait(self_ty, &[arguments_tuple.into()]),
344     };
345     sig.map_bound(|sig| (trait_ref, sig.output()))
346 }
347
348 pub fn generator_trait_ref_and_outputs<'tcx>(
349     tcx: TyCtxt<'tcx>,
350     fn_trait_def_id: DefId,
351     self_ty: Ty<'tcx>,
352     sig: ty::PolyGenSig<'tcx>,
353 ) -> ty::Binder<'tcx, (ty::TraitRef<'tcx>, Ty<'tcx>, Ty<'tcx>)> {
354     debug_assert!(!self_ty.has_escaping_bound_vars());
355     let trait_ref = ty::TraitRef {
356         def_id: fn_trait_def_id,
357         substs: tcx.mk_substs_trait(self_ty, &[sig.skip_binder().resume_ty.into()]),
358     };
359     sig.map_bound(|sig| (trait_ref, sig.yield_ty, sig.return_ty))
360 }
361
362 pub fn impl_item_is_final(tcx: TyCtxt<'_>, assoc_item: &ty::AssocItem) -> bool {
363     assoc_item.defaultness.is_final() && tcx.impl_defaultness(assoc_item.container.id()).is_final()
364 }
365
366 pub enum TupleArgumentsFlag {
367     Yes,
368     No,
369 }