]> git.lizzy.rs Git - rust.git/blob - crates/hir_ty/src/infer/closure.rs
Implement `CoerceMany`
[rust.git] / crates / hir_ty / src / infer / closure.rs
1 //! Inference of closure parameter types based on the closure's expected type.
2
3 use chalk_ir::{cast::Cast, AliasTy, FnSubst, WhereClause};
4 use hir_def::{expr::ExprId, HasModule};
5 use smallvec::SmallVec;
6
7 use crate::{
8     to_chalk_trait_id, utils, ChalkTraitId, DynTy, FnPointer, FnSig, Interner, Substitution, Ty,
9     TyExt, TyKind,
10 };
11
12 use super::{Expectation, InferenceContext};
13
14 impl InferenceContext<'_> {
15     pub(super) fn deduce_closure_type_from_expectations(
16         &mut self,
17         closure_expr: ExprId,
18         closure_ty: &Ty,
19         sig_ty: &Ty,
20         expectation: &Expectation,
21     ) {
22         let expected_ty = match expectation.to_option(&mut self.table) {
23             Some(ty) => ty,
24             None => return,
25         };
26
27         // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
28         let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);
29
30         // Deduction based on the expected `dyn Fn` is done separately.
31         if let TyKind::Dyn(dyn_ty) = expected_ty.kind(&Interner) {
32             if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) {
33                 let expected_sig_ty = TyKind::Function(sig).intern(&Interner);
34
35                 self.unify(sig_ty, &expected_sig_ty);
36             }
37         }
38     }
39
40     fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> {
41         // Search for a predicate like `<$self as FnX<Args>>::Output == Ret`
42
43         let fn_traits: SmallVec<[ChalkTraitId; 3]> =
44             utils::fn_traits(self.db.upcast(), self.owner.module(self.db.upcast()).krate())
45                 .map(|tid| to_chalk_trait_id(tid))
46                 .collect();
47
48         let self_ty = TyKind::Error.intern(&Interner);
49         let bounds = dyn_ty.bounds.clone().substitute(&Interner, &[self_ty.cast(&Interner)]);
50         for bound in bounds.iter(&Interner) {
51             // NOTE(skip_binders): the extracted types are rebound by the returned `FnPointer`
52             match bound.skip_binders() {
53                 WhereClause::AliasEq(eq) => match &eq.alias {
54                     AliasTy::Projection(projection) => {
55                         let assoc_data = self.db.associated_ty_data(projection.associated_ty_id);
56                         if !fn_traits.contains(&assoc_data.trait_id) {
57                             return None;
58                         }
59
60                         // Skip `Self`, get the type argument.
61                         let arg = projection.substitution.as_slice(&Interner).get(1)?;
62                         if let Some(subst) = arg.ty(&Interner)?.as_tuple() {
63                             let generic_args = subst.as_slice(&Interner);
64                             let mut sig_tys = Vec::new();
65                             for arg in generic_args {
66                                 sig_tys.push(arg.ty(&Interner)?.clone());
67                             }
68                             sig_tys.push(eq.ty.clone());
69
70                             cov_mark::hit!(dyn_fn_param_informs_call_site_closure_signature);
71                             return Some(FnPointer {
72                                 num_binders: bound.len(&Interner),
73                                 sig: FnSig {
74                                     abi: (),
75                                     safety: chalk_ir::Safety::Safe,
76                                     variadic: false,
77                                 },
78                                 substitution: FnSubst(Substitution::from_iter(
79                                     &Interner,
80                                     sig_tys.clone(),
81                                 )),
82                             });
83                         }
84                     }
85                     AliasTy::Opaque(_) => {}
86                 },
87                 _ => {}
88             }
89         }
90
91         None
92     }
93 }