impl_block: Option<ImplBlock>,
var_unification_table: InPlaceUnificationTable<TypeVarId>,
type_of: FxHashMap<LocalSyntaxPtr, Ty>,
+ /// The return type of the function being inferred.
+ return_ty: Ty,
}
impl<'a, D: HirDatabase> InferenceContext<'a, D> {
InferenceContext {
type_of: FxHashMap::default(),
var_unification_table: InPlaceUnificationTable::new(),
- self_param: None, // set during parameter typing
+ self_param: None, // set during parameter typing
+ return_ty: Ty::Unknown, // set in collect_fn_signature
db,
scopes,
module,
self.type_of.insert(LocalSyntaxPtr::new(node), ty);
}
+ fn make_ty(&self, type_ref: &TypeRef) -> Cancelable<Ty> {
+ Ty::from_hir(self.db, &self.module, self.impl_block.as_ref(), type_ref)
+ }
+
+ fn make_ty_opt(&self, type_ref: Option<&TypeRef>) -> Cancelable<Ty> {
+ Ty::from_hir_opt(self.db, &self.module, self.impl_block.as_ref(), type_ref)
+ }
+
fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
match (ty1, ty2) {
(Ty::Unknown, ..) => true,
self.write_ty(node.syntax(), ty.clone());
Ok(ty)
}
-}
-
-pub fn infer(db: &impl HirDatabase, def_id: DefId) -> Cancelable<Arc<InferenceResult>> {
- let function = Function::new(def_id); // TODO: consts also need inference
- let scopes = function.scopes(db);
- let module = function.module(db)?;
- let impl_block = function.impl_block(db)?;
- let mut ctx = InferenceContext::new(db, scopes, module, impl_block);
- let syntax = function.syntax(db);
- let node = syntax.borrowed();
-
- if let Some(param_list) = node.param_list() {
- if let Some(self_param) = param_list.self_param() {
- let self_type = if let Some(impl_block) = &ctx.impl_block {
- if let Some(type_ref) = self_param.type_ref() {
- let ty = Ty::from_ast(db, &ctx.module, ctx.impl_block.as_ref(), type_ref)?;
- ctx.insert_type_vars(ty)
+ fn collect_fn_signature(&mut self, node: ast::FnDef) -> Cancelable<()> {
+ if let Some(param_list) = node.param_list() {
+ if let Some(self_param) = param_list.self_param() {
+ let self_type = if let Some(type_ref) = self_param.type_ref() {
+ let ty = self.make_ty(&TypeRef::from_ast(type_ref))?;
+ self.insert_type_vars(ty)
} else {
// TODO this should be handled by desugaring during HIR conversion
- let ty = Ty::from_hir(
- db,
- &ctx.module,
- ctx.impl_block.as_ref(),
- impl_block.target(),
- )?;
+ let ty = self.make_ty_opt(self.impl_block.as_ref().map(|i| i.target()))?;
let ty = match self_param.flavor() {
ast::SelfParamFlavor::Owned => ty,
ast::SelfParamFlavor::Ref => Ty::Ref(Arc::new(ty), Mutability::Shared),
ast::SelfParamFlavor::MutRef => Ty::Ref(Arc::new(ty), Mutability::Mut),
};
- ctx.insert_type_vars(ty)
+ self.insert_type_vars(ty)
+ };
+ if let Some(self_kw) = self_param.self_kw() {
+ let self_param = LocalSyntaxPtr::new(self_kw.syntax());
+ self.self_param = Some(self_param);
+ self.type_of.insert(self_param, self_type);
}
- } else {
- log::debug!(
- "No impl block found, but self param for function {:?}",
- def_id
- );
- ctx.new_type_var()
- };
- if let Some(self_kw) = self_param.self_kw() {
- let self_param = LocalSyntaxPtr::new(self_kw.syntax());
- ctx.self_param = Some(self_param);
- ctx.type_of.insert(self_param, self_type);
+ }
+ for param in param_list.params() {
+ let pat = if let Some(pat) = param.pat() {
+ pat
+ } else {
+ continue;
+ };
+ let ty = if let Some(type_ref) = param.type_ref() {
+ let ty = self.make_ty(&TypeRef::from_ast(type_ref))?;
+ self.insert_type_vars(ty)
+ } else {
+ // missing type annotation
+ self.new_type_var()
+ };
+ self.type_of.insert(LocalSyntaxPtr::new(pat.syntax()), ty);
}
}
- for param in param_list.params() {
- let pat = if let Some(pat) = param.pat() {
- pat
- } else {
- continue;
- };
- let ty = if let Some(type_ref) = param.type_ref() {
- let ty = Ty::from_ast(db, &ctx.module, ctx.impl_block.as_ref(), type_ref)?;
- ctx.insert_type_vars(ty)
- } else {
- // missing type annotation
- ctx.new_type_var()
- };
- ctx.type_of.insert(LocalSyntaxPtr::new(pat.syntax()), ty);
- }
+
+ self.return_ty = if let Some(type_ref) = node.ret_type().and_then(|n| n.type_ref()) {
+ let ty = self.make_ty(&TypeRef::from_ast(type_ref))?;
+ self.insert_type_vars(ty)
+ } else {
+ Ty::unit()
+ };
+
+ Ok(())
}
+}
- let ret_ty = if let Some(type_ref) = node.ret_type().and_then(|n| n.type_ref()) {
- let ty = Ty::from_ast(db, &ctx.module, ctx.impl_block.as_ref(), type_ref)?;
- ctx.insert_type_vars(ty)
- } else {
- Ty::unit()
- };
+pub fn infer(db: &impl HirDatabase, def_id: DefId) -> Cancelable<Arc<InferenceResult>> {
+ let function = Function::new(def_id); // TODO: consts also need inference
+ let scopes = function.scopes(db);
+ let module = function.module(db)?;
+ let impl_block = function.impl_block(db)?;
+ let mut ctx = InferenceContext::new(db, scopes, module, impl_block);
+
+ let syntax = function.syntax(db);
+ let node = syntax.borrowed();
+
+ ctx.collect_fn_signature(node)?;
if let Some(block) = node.body() {
- ctx.infer_block(block, &Expectation::has_type(ret_ty))?;
+ ctx.infer_block(block, &Expectation::has_type(ctx.return_ty.clone()))?;
}
Ok(Arc::new(ctx.resolve_all()))