]> git.lizzy.rs Git - rust.git/blobdiff - crates/hir_ty/src/infer/expr.rs
Merge #11842
[rust.git] / crates / hir_ty / src / infer / expr.rs
index e78a6377e5e6d99d1b892a46c39a21f2da73ba46..0b67f2c32e58d7e2c0336054f29d1c9d97999d30 100644 (file)
@@ -28,7 +28,7 @@
     lower::{
         const_or_path_to_chalk, generic_arg_to_chalk, lower_to_chalk_mutability, ParamLoweringMode,
     },
-    mapping::from_chalk,
+    mapping::{from_chalk, ToChalk},
     method_resolution,
     primitive::{self, UintTy},
     static_lifetime, to_chalk_trait_id,
@@ -67,14 +67,13 @@ pub(super) fn infer_expr_coerce(&mut self, expr: ExprId, expected: &Expectation)
         let ty = self.infer_expr_inner(expr, expected);
         if let Some(target) = expected.only_has_type(&mut self.table) {
             match self.coerce(Some(expr), &ty, &target) {
-                Ok(res) => res.value,
+                Ok(res) => res,
                 Err(_) => {
-                    self.result
-                        .type_mismatches
-                        .insert(expr.into(), TypeMismatch { expected: target, actual: ty.clone() });
-                    // Return actual type when type mismatch.
-                    // This is needed for diagnostic when return type mismatch.
-                    ty
+                    self.result.type_mismatches.insert(
+                        expr.into(),
+                        TypeMismatch { expected: target.clone(), actual: ty.clone() },
+                    );
+                    target
                 }
             }
         } else {
@@ -157,9 +156,17 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                 self.err_ty()
             }
             Expr::Async { body } => {
+                let ret_ty = self.table.new_type_var();
+                let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
+                let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone());
+
+                let inner_ty = self.infer_expr_coerce(*body, &Expectation::has_type(ret_ty));
+
+                self.diverges = prev_diverges;
+                self.return_ty = prev_ret_ty;
+
                 // Use the first type parameter as the output type of future.
                 // existential type AsyncBlockImplTrait<InnerType>: Future<Output = InnerType>
-                let inner_ty = self.infer_expr(*body, &Expectation::none());
                 let impl_trait_id = crate::ImplTraitId::AsyncBlockTypeImplTrait(self.owner, *body);
                 let opaque_ty_id = self.db.intern_impl_trait_id(impl_trait_id).into();
                 TyKind::OpaqueType(opaque_ty_id, Substitution::from1(Interner, inner_ty))
@@ -279,21 +286,29 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                 let callee_ty = self.infer_expr(*callee, &Expectation::none());
                 let mut derefs = Autoderef::new(&mut self.table, callee_ty.clone());
                 let mut res = None;
+                let mut derefed_callee = callee_ty.clone();
                 // manual loop to be able to access `derefs.table`
                 while let Some((callee_deref_ty, _)) = derefs.next() {
                     res = derefs.table.callable_sig(&callee_deref_ty, args.len());
                     if res.is_some() {
+                        derefed_callee = callee_deref_ty;
                         break;
                     }
                 }
-                let (param_tys, ret_ty): (Vec<Ty>, Ty) = match res {
+                // if the function is unresolved, we use is_varargs=true to
+                // suppress the arg count diagnostic here
+                let is_varargs =
+                    derefed_callee.callable_sig(self.db).map_or(false, |sig| sig.is_varargs)
+                        || res.is_none();
+                let (param_tys, ret_ty) = match res {
                     Some(res) => {
                         let adjustments = auto_deref_adjust_steps(&derefs);
                         self.write_expr_adj(*callee, adjustments);
                         res
                     }
-                    None => (Vec::new(), self.err_ty()),
+                    None => (Vec::new(), self.err_ty()), // FIXME diagnostic
                 };
+                let indices_to_skip = self.check_legacy_const_generics(derefed_callee, args);
                 self.register_obligations_for_call(&callee_ty);
 
                 let expected_inputs = self.expected_inputs_for_expected_output(
@@ -302,7 +317,14 @@ fn infer_expr_inner(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
                     param_tys.clone(),
                 );
 
-                self.check_call_arguments(args, &expected_inputs, &param_tys);
+                self.check_call_arguments(
+                    tgt_expr,
+                    args,
+                    &expected_inputs,
+                    &param_tys,
+                    &indices_to_skip,
+                    is_varargs,
+                );
                 self.normalize_associated_types_in(ret_ty)
             }
             Expr::MethodCall { receiver, args, method_name, generic_args } => self
@@ -891,9 +913,16 @@ fn infer_block(
                 self.table.new_maybe_never_var()
             } else {
                 if let Some(t) = expected.only_has_type(&mut self.table) {
-                    let _ = self.coerce(Some(expr), &TyBuilder::unit(), &t);
+                    if self.coerce(Some(expr), &TyBuilder::unit(), &t).is_err() {
+                        self.result.type_mismatches.insert(
+                            expr.into(),
+                            TypeMismatch { expected: t.clone(), actual: TyBuilder::unit() },
+                        );
+                    }
+                    t
+                } else {
+                    TyBuilder::unit()
                 }
-                TyBuilder::unit()
             }
         }
     }
@@ -937,22 +966,28 @@ fn infer_method_call(
         };
         let method_ty = method_ty.substitute(Interner, &substs);
         self.register_obligations_for_call(&method_ty);
-        let (formal_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) {
-            Some(sig) => {
-                if !sig.params().is_empty() {
-                    (sig.params()[0].clone(), sig.params()[1..].to_vec(), sig.ret().clone())
-                } else {
-                    (self.err_ty(), Vec::new(), sig.ret().clone())
+        let (formal_receiver_ty, param_tys, ret_ty, is_varargs) =
+            match method_ty.callable_sig(self.db) {
+                Some(sig) => {
+                    if !sig.params().is_empty() {
+                        (
+                            sig.params()[0].clone(),
+                            sig.params()[1..].to_vec(),
+                            sig.ret().clone(),
+                            sig.is_varargs,
+                        )
+                    } else {
+                        (self.err_ty(), Vec::new(), sig.ret().clone(), sig.is_varargs)
+                    }
                 }
-            }
-            None => (self.err_ty(), Vec::new(), self.err_ty()),
-        };
+                None => (self.err_ty(), Vec::new(), self.err_ty(), true),
+            };
         self.unify(&formal_receiver_ty, &receiver_ty);
 
         let expected_inputs =
             self.expected_inputs_for_expected_output(expected, ret_ty.clone(), param_tys.clone());
 
-        self.check_call_arguments(args, &expected_inputs, &param_tys);
+        self.check_call_arguments(tgt_expr, args, &expected_inputs, &param_tys, &[], is_varargs);
         self.normalize_associated_types_in(ret_ty)
     }
 
@@ -983,24 +1018,50 @@ fn expected_inputs_for_expected_output(
         }
     }
 
-    fn check_call_arguments(&mut self, args: &[ExprId], expected_inputs: &[Ty], param_tys: &[Ty]) {
+    fn check_call_arguments(
+        &mut self,
+        expr: ExprId,
+        args: &[ExprId],
+        expected_inputs: &[Ty],
+        param_tys: &[Ty],
+        skip_indices: &[u32],
+        is_varargs: bool,
+    ) {
+        if args.len() != param_tys.len() + skip_indices.len() && !is_varargs {
+            self.push_diagnostic(InferenceDiagnostic::MismatchedArgCount {
+                call_expr: expr,
+                expected: param_tys.len() + skip_indices.len(),
+                found: args.len(),
+            });
+        }
+
         // Quoting https://github.com/rust-lang/rust/blob/6ef275e6c3cb1384ec78128eceeb4963ff788dca/src/librustc_typeck/check/mod.rs#L3325 --
         // We do this in a pretty awful way: first we type-check any arguments
         // that are not closures, then we type-check the closures. This is so
         // that we have more information about the types of arguments when we
         // type-check the functions. This isn't really the right way to do this.
         for &check_closures in &[false, true] {
+            let mut skip_indices = skip_indices.into_iter().copied().fuse().peekable();
             let param_iter = param_tys.iter().cloned().chain(repeat(self.err_ty()));
             let expected_iter = expected_inputs
                 .iter()
                 .cloned()
                 .chain(param_iter.clone().skip(expected_inputs.len()));
-            for ((&arg, param_ty), expected_ty) in args.iter().zip(param_iter).zip(expected_iter) {
+            for (idx, ((&arg, param_ty), expected_ty)) in
+                args.iter().zip(param_iter).zip(expected_iter).enumerate()
+            {
                 let is_closure = matches!(&self.body[arg], Expr::Lambda { .. });
                 if is_closure != check_closures {
                     continue;
                 }
 
+                while skip_indices.peek().map_or(false, |i| *i < idx as u32) {
+                    skip_indices.next();
+                }
+                if skip_indices.peek().copied() == Some(idx as u32) {
+                    continue;
+                }
+
                 // the difference between param_ty and expected here is that
                 // expected is the parameter when the expected *return* type is
                 // taken into account. So in `let _: &[i32] = identity(&[1, 2])`
@@ -1088,9 +1149,20 @@ fn substs_for_method_call(
                 }
             }
         };
-        let supplied_params = substs.len();
-        for _ in supplied_params..total_len {
-            substs.push(GenericArgData::Ty(self.table.new_type_var()).intern(Interner));
+        for (id, data) in def_generics.iter().skip(substs.len()) {
+            match data {
+                TypeOrConstParamData::TypeParamData(_) => {
+                    substs.push(GenericArgData::Ty(self.table.new_type_var()).intern(Interner))
+                }
+                TypeOrConstParamData::ConstParamData(_) => {
+                    substs.push(
+                        GenericArgData::Const(self.table.new_const_var(
+                            self.db.const_param_ty(ConstParamId::from_unchecked(id)),
+                        ))
+                        .intern(Interner),
+                    )
+                }
+            }
         }
         assert_eq!(substs.len(), total_len);
         Substitution::from_iter(Interner, substs)
@@ -1129,6 +1201,57 @@ fn register_obligations_for_call(&mut self, callable_ty: &Ty) {
         }
     }
 
+    /// Returns the argument indices to skip.
+    fn check_legacy_const_generics(&mut self, callee: Ty, args: &[ExprId]) -> Vec<u32> {
+        let (func, subst) = match callee.kind(Interner) {
+            TyKind::FnDef(fn_id, subst) => {
+                let callable = CallableDefId::from_chalk(self.db, *fn_id);
+                let func = match callable {
+                    CallableDefId::FunctionId(f) => f,
+                    _ => return Vec::new(),
+                };
+                (func, subst)
+            }
+            _ => return Vec::new(),
+        };
+
+        let data = self.db.function_data(func);
+        if data.legacy_const_generics_indices.is_empty() {
+            return Vec::new();
+        }
+
+        // only use legacy const generics if the param count matches with them
+        if data.params.len() + data.legacy_const_generics_indices.len() != args.len() {
+            if args.len() <= data.params.len() {
+                return Vec::new();
+            } else {
+                // there are more parameters than there should be without legacy
+                // const params; use them
+                let mut indices = data.legacy_const_generics_indices.clone();
+                indices.sort();
+                return indices;
+            }
+        }
+
+        // check legacy const parameters
+        for (subst_idx, arg_idx) in data.legacy_const_generics_indices.iter().copied().enumerate() {
+            let arg = match subst.at(Interner, subst_idx).constant(Interner) {
+                Some(c) => c,
+                None => continue, // not a const parameter?
+            };
+            if arg_idx >= args.len() as u32 {
+                continue;
+            }
+            let _ty = arg.data(Interner).ty.clone();
+            let expected = Expectation::none(); // FIXME use actual const ty, when that is lowered correctly
+            self.infer_expr(args[arg_idx as usize], &expected);
+            // FIXME: evaluate and unify with the const
+        }
+        let mut indices = data.legacy_const_generics_indices.clone();
+        indices.sort();
+        indices
+    }
+
     fn builtin_binary_op_return_ty(&mut self, op: BinaryOp, lhs_ty: Ty, rhs_ty: Ty) -> Option<Ty> {
         let lhs_ty = self.resolve_ty_shallow(&lhs_ty);
         let rhs_ty = self.resolve_ty_shallow(&rhs_ty);