]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs
094e460dbf79b0b08b35f2911c0bf83d78d8faca
[rust.git] / src / tools / rust-analyzer / crates / hir-ty / src / infer / closure.rs
1 //! Inference of closure parameter types based on the closure's expected type.
2
3 use chalk_ir::{cast::Cast, AliasEq, AliasTy, FnSubst, WhereClause};
4 use hir_def::{expr::ExprId, HasModule};
5 use smallvec::SmallVec;
6
7 use crate::{
8     to_chalk_trait_id, utils, ChalkTraitId, DynTy, FnPointer, FnSig, Interner, Substitution, Ty,
9     TyExt, TyKind,
10 };
11
12 use super::{Expectation, InferenceContext};
13
14 impl InferenceContext<'_> {
15     // This function handles both closures and generators.
16     pub(super) fn deduce_closure_type_from_expectations(
17         &mut self,
18         closure_expr: ExprId,
19         closure_ty: &Ty,
20         sig_ty: &Ty,
21         expectation: &Expectation,
22     ) {
23         let expected_ty = match expectation.to_option(&mut self.table) {
24             Some(ty) => ty,
25             None => return,
26         };
27
28         // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
29         let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);
30
31         // Generators are not Fn* so return early.
32         if matches!(closure_ty.kind(Interner), TyKind::Generator(..)) {
33             return;
34         }
35
36         // Deduction based on the expected `dyn Fn` is done separately.
37         if let TyKind::Dyn(dyn_ty) = expected_ty.kind(Interner) {
38             if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) {
39                 let expected_sig_ty = TyKind::Function(sig).intern(Interner);
40
41                 self.unify(sig_ty, &expected_sig_ty);
42             }
43         }
44     }
45
46     fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> {
47         // Search for a predicate like `<$self as FnX<Args>>::Output == Ret`
48
49         let fn_traits: SmallVec<[ChalkTraitId; 3]> =
50             utils::fn_traits(self.db.upcast(), self.owner.module(self.db.upcast()).krate())
51                 .map(to_chalk_trait_id)
52                 .collect();
53
54         let self_ty = TyKind::Error.intern(Interner);
55         let bounds = dyn_ty.bounds.clone().substitute(Interner, &[self_ty.cast(Interner)]);
56         for bound in bounds.iter(Interner) {
57             // NOTE(skip_binders): the extracted types are rebound by the returned `FnPointer`
58             if let WhereClause::AliasEq(AliasEq { alias: AliasTy::Projection(projection), ty }) =
59                 bound.skip_binders()
60             {
61                 let assoc_data = self.db.associated_ty_data(projection.associated_ty_id);
62                 if !fn_traits.contains(&assoc_data.trait_id) {
63                     return None;
64                 }
65
66                 // Skip `Self`, get the type argument.
67                 let arg = projection.substitution.as_slice(Interner).get(1)?;
68                 if let Some(subst) = arg.ty(Interner)?.as_tuple() {
69                     let generic_args = subst.as_slice(Interner);
70                     let mut sig_tys = Vec::new();
71                     for arg in generic_args {
72                         sig_tys.push(arg.ty(Interner)?.clone());
73                     }
74                     sig_tys.push(ty.clone());
75
76                     cov_mark::hit!(dyn_fn_param_informs_call_site_closure_signature);
77                     return Some(FnPointer {
78                         num_binders: bound.len(Interner),
79                         sig: FnSig { abi: (), safety: chalk_ir::Safety::Safe, variadic: false },
80                         substitution: FnSubst(Substitution::from_iter(Interner, sig_tys)),
81                     });
82                 }
83             }
84         }
85
86         None
87     }
88 }