]> git.lizzy.rs Git - rust.git/commitdiff
Implement type inference for generator and yield expressions
authorRyo Yoshida <low.ryoshida@gmail.com>
Tue, 6 Sep 2022 08:48:06 +0000 (17:48 +0900)
committerRyo Yoshida <low.ryoshida@gmail.com>
Mon, 12 Sep 2022 17:42:52 +0000 (02:42 +0900)
crates/hir-ty/src/db.rs
crates/hir-ty/src/infer.rs
crates/hir-ty/src/infer/closure.rs
crates/hir-ty/src/infer/expr.rs
crates/hir-ty/src/mapping.rs

index b385b1cafaefd09f3db05d293cb88fb0737e1f6b..c4f7685cd140a30609f889104b52fd3af24812dc 100644 (file)
@@ -116,6 +116,8 @@ fn intern_type_or_const_param_id(
     fn intern_impl_trait_id(&self, id: ImplTraitId) -> InternedOpaqueTyId;
     #[salsa::interned]
     fn intern_closure(&self, id: (DefWithBodyId, ExprId)) -> InternedClosureId;
+    #[salsa::interned]
+    fn intern_generator(&self, id: (DefWithBodyId, ExprId)) -> InternedGeneratorId;
 
     #[salsa::invoke(chalk_db::associated_ty_data_query)]
     fn associated_ty_data(&self, id: chalk_db::AssocTypeId) -> Arc<chalk_db::AssociatedTyDatum>;
@@ -218,6 +220,10 @@ fn _assert_object_safe(_: &dyn HirDatabase) {}
 pub struct InternedClosureId(salsa::InternId);
 impl_intern_key!(InternedClosureId);
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct InternedGeneratorId(salsa::InternId);
+impl_intern_key!(InternedGeneratorId);
+
 /// This exists just for Chalk, because Chalk just has a single `FnDefId` where
 /// we have different IDs for struct and enum variant constructors.
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
index 10ffde87eef1491b4c6777db2c19309ff6f36f4e..d351eb599ce468930af9187e335a29db79008f43 100644 (file)
@@ -332,7 +332,7 @@ pub struct InferenceResult {
     /// unresolved or missing subpatterns or subpatterns of mismatched types.
     pub type_of_pat: ArenaMap<PatId, Ty>,
     type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch>,
-    /// Interned Unknown to return references to.
+    /// Interned common types to return references to.
     standard_types: InternedStandardTypes,
     /// Stores the types which were implicitly dereferenced in pattern binding modes.
     pub pat_adjustments: FxHashMap<PatId, Vec<Ty>>,
@@ -412,6 +412,8 @@ pub(crate) struct InferenceContext<'a> {
     /// closures, but currently this is the only field that will change there,
     /// so it doesn't make sense.
     return_ty: Ty,
+    /// The resume type and the yield type, respectively, of the generator being inferred.
+    resume_yield_tys: Option<(Ty, Ty)>,
     diverges: Diverges,
     breakables: Vec<BreakableContext>,
 }
@@ -476,6 +478,7 @@ fn new(
             table: unify::InferenceTable::new(db, trait_env.clone()),
             trait_env,
             return_ty: TyKind::Error.intern(Interner), // set in collect_fn_signature
+            resume_yield_tys: None,
             db,
             owner,
             body,
index 3ead929098bcc25843d92fcd6b28c09ac0f07f16..094e460dbf79b0b08b35f2911c0bf83d78d8faca 100644 (file)
@@ -12,6 +12,7 @@
 use super::{Expectation, InferenceContext};
 
 impl InferenceContext<'_> {
+    // This function handles both closures and generators.
     pub(super) fn deduce_closure_type_from_expectations(
         &mut self,
         closure_expr: ExprId,
@@ -27,6 +28,11 @@ pub(super) fn deduce_closure_type_from_expectations(
         // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
         let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);
 
+        // Generators are not Fn* so return early.
+        if matches!(closure_ty.kind(Interner), TyKind::Generator(..)) {
+            return;
+        }
+
         // Deduction based on the expected `dyn Fn` is done separately.
         if let TyKind::Dyn(dyn_ty) = expected_ty.kind(Interner) {
             if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) {
index f0382846a6affcbe0c5f59f0771bc68e3d8e0052..e3d6be23e6586e33f1895ad81e650430552e4d53 100644 (file)
     cast::Cast, fold::Shift, DebruijnIndex, GenericArgData, Mutability, TyVariableKind,
 };
 use hir_def::{
-    expr::{ArithOp, Array, BinaryOp, CmpOp, Expr, ExprId, LabelId, Literal, Statement, UnaryOp},
+    expr::{
+        ArithOp, Array, BinaryOp, ClosureKind, CmpOp, Expr, ExprId, LabelId, Literal, Statement,
+        UnaryOp,
+    },
     generics::TypeOrConstParamData,
     path::{GenericArg, GenericArgs},
     resolver::resolver_for_expr,
@@ -216,7 +219,7 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                 self.diverges = Diverges::Maybe;
                 TyBuilder::unit()
             }
-            Expr::Closure { body, args, ret_type, arg_types, closure_kind: _ } => {
+            Expr::Closure { body, args, ret_type, arg_types, closure_kind } => {
                 assert_eq!(args.len(), arg_types.len());
 
                 let mut sig_tys = Vec::new();
@@ -244,20 +247,40 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                     ),
                 })
                 .intern(Interner);
-                let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into();
-                let closure_ty =
-                    TyKind::Closure(closure_id, Substitution::from1(Interner, sig_ty.clone()))
-                        .intern(Interner);
+
+                let (ty, resume_yield_tys) = if matches!(closure_kind, ClosureKind::Generator(_)) {
+                    // FIXME: report error when there are more than 1 parameter.
+                    let resume_ty = match sig_tys.first() {
+                        // When `sig_tys.len() == 1` the first type is the return type, not the
+                        // first parameter type.
+                        Some(ty) if sig_tys.len() > 1 => ty.clone(),
+                        _ => self.result.standard_types.unit.clone(),
+                    };
+                    let yield_ty = self.table.new_type_var();
+
+                    let subst = TyBuilder::subst_for_generator(self.db, self.owner)
+                        .push(resume_ty.clone())
+                        .push(yield_ty.clone())
+                        .push(ret_ty.clone())
+                        .build();
+
+                    let generator_id = self.db.intern_generator((self.owner, tgt_expr)).into();
+                    let generator_ty = TyKind::Generator(generator_id, subst).intern(Interner);
+
+                    (generator_ty, Some((resume_ty, yield_ty)))
+                } else {
+                    let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into();
+                    let closure_ty =
+                        TyKind::Closure(closure_id, Substitution::from1(Interner, sig_ty.clone()))
+                            .intern(Interner);
+
+                    (closure_ty, None)
+                };
 
                 // Eagerly try to relate the closure type with the expected
                 // type, otherwise we often won't have enough information to
                 // infer the body.
-                self.deduce_closure_type_from_expectations(
-                    tgt_expr,
-                    &closure_ty,
-                    &sig_ty,
-                    expected,
-                );
+                self.deduce_closure_type_from_expectations(tgt_expr, &ty, &sig_ty, expected);
 
                 // Now go through the argument patterns
                 for (arg_pat, arg_ty) in args.iter().zip(sig_tys) {
@@ -266,6 +289,8 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
 
                 let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
                 let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone());
+                let prev_resume_yield_tys =
+                    mem::replace(&mut self.resume_yield_tys, resume_yield_tys);
 
                 self.with_breakable_ctx(BreakableKind::Border, self.err_ty(), None, |this| {
                     this.infer_expr_coerce(*body, &Expectation::has_type(ret_ty));
@@ -273,8 +298,9 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
 
                 self.diverges = prev_diverges;
                 self.return_ty = prev_ret_ty;
+                self.resume_yield_tys = prev_resume_yield_tys;
 
-                closure_ty
+                ty
             }
             Expr::Call { callee, args, .. } => {
                 let callee_ty = self.infer_expr(*callee, &Expectation::none());
@@ -423,11 +449,18 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                 TyKind::Never.intern(Interner)
             }
             Expr::Yield { expr } => {
-                // FIXME: track yield type for coercion
-                if let Some(expr) = expr {
-                    self.infer_expr(*expr, &Expectation::none());
+                if let Some((resume_ty, yield_ty)) = self.resume_yield_tys.clone() {
+                    if let Some(expr) = expr {
+                        self.infer_expr_coerce(*expr, &Expectation::has_type(yield_ty));
+                    } else {
+                        let unit = self.result.standard_types.unit.clone();
+                        let _ = self.coerce(Some(tgt_expr), &unit, &yield_ty);
+                    }
+                    resume_ty
+                } else {
+                    // FIXME: report error (yield expr in non-generator)
+                    TyKind::Error.intern(Interner)
                 }
-                TyKind::Never.intern(Interner)
             }
             Expr::RecordLit { path, fields, spread, .. } => {
                 let (ty, def_id) = self.resolve_variant(path.as_deref(), false);
index d765fee0e1f4e70b523f26db36db9598b916b1f1..f80fb39c1f84e24854bfecbe3fd06a06786f86b2 100644 (file)
@@ -103,6 +103,18 @@ fn from(id: crate::db::InternedClosureId) -> Self {
     }
 }
 
+impl From<chalk_ir::GeneratorId<Interner>> for crate::db::InternedGeneratorId {
+    fn from(id: chalk_ir::GeneratorId<Interner>) -> Self {
+        Self::from_intern_id(id.0)
+    }
+}
+
+impl From<crate::db::InternedGeneratorId> for chalk_ir::GeneratorId<Interner> {
+    fn from(id: crate::db::InternedGeneratorId) -> Self {
+        chalk_ir::GeneratorId(id.as_intern_id())
+    }
+}
+
 pub fn to_foreign_def_id(id: TypeAliasId) -> ForeignDefId {
     chalk_ir::ForeignDefId(salsa::InternKey::as_intern_id(&id))
 }