]> git.lizzy.rs Git - rust.git/blobdiff - compiler/rustc_hir_typeck/src/fn_ctxt/adjust_fulfillment_errors.rs
Split fn_ctxt/adjust_fulfillment_errors from fn_ctxt/checks
[rust.git] / compiler / rustc_hir_typeck / src / fn_ctxt / adjust_fulfillment_errors.rs
index 2eab68050d43043d6395205cbb863cb908445677..db1acb599271696008ec0143c1f68a587ddf173f 100644 (file)
 use crate::FnCtxt;
 use rustc_hir as hir;
 use rustc_hir::def::Res;
-use rustc_middle::ty::{self, DefIdTree, Ty};
+use rustc_hir::def_id::DefId;
+use rustc_infer::traits::ObligationCauseCode;
+use rustc_middle::ty::{self, DefIdTree, Ty, TypeSuperVisitable, TypeVisitable, TypeVisitor};
+use rustc_span::{self, Span};
 use rustc_trait_selection::traits;
 
+use std::ops::ControlFlow;
+
 impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
+    pub fn adjust_fulfillment_error_for_expr_obligation(
+        &self,
+        error: &mut traits::FulfillmentError<'tcx>,
+    ) -> bool {
+        let (traits::ExprItemObligation(def_id, hir_id, idx) | traits::ExprBindingObligation(def_id, _, hir_id, idx))
+            = *error.obligation.cause.code().peel_derives() else { return false; };
+        let hir = self.tcx.hir();
+        let hir::Node::Expr(expr) = hir.get(hir_id) else { return false; };
+
+        let Some(unsubstituted_pred) =
+            self.tcx.predicates_of(def_id).instantiate_identity(self.tcx).predicates.into_iter().nth(idx)
+            else { return false; };
+
+        let generics = self.tcx.generics_of(def_id);
+        let predicate_substs = match unsubstituted_pred.kind().skip_binder() {
+            ty::PredicateKind::Clause(ty::Clause::Trait(pred)) => pred.trait_ref.substs,
+            ty::PredicateKind::Clause(ty::Clause::Projection(pred)) => pred.projection_ty.substs,
+            _ => ty::List::empty(),
+        };
+
+        let find_param_matching = |matches: &dyn Fn(&ty::ParamTy) -> bool| {
+            predicate_substs.types().find_map(|ty| {
+                ty.walk().find_map(|arg| {
+                    if let ty::GenericArgKind::Type(ty) = arg.unpack()
+                        && let ty::Param(param_ty) = ty.kind()
+                        && matches(param_ty)
+                    {
+                        Some(arg)
+                    } else {
+                        None
+                    }
+                })
+            })
+        };
+
+        // Prefer generics that are local to the fn item, since these are likely
+        // to be the cause of the unsatisfied predicate.
+        let mut param_to_point_at = find_param_matching(&|param_ty| {
+            self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) == def_id
+        });
+        // Fall back to generic that isn't local to the fn item. This will come
+        // from a trait or impl, for example.
+        let mut fallback_param_to_point_at = find_param_matching(&|param_ty| {
+            self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) != def_id
+                && param_ty.name != rustc_span::symbol::kw::SelfUpper
+        });
+        // Finally, the `Self` parameter is possibly the reason that the predicate
+        // is unsatisfied. This is less likely to be true for methods, because
+        // method probe means that we already kinda check that the predicates due
+        // to the `Self` type are true.
+        let mut self_param_to_point_at =
+            find_param_matching(&|param_ty| param_ty.name == rustc_span::symbol::kw::SelfUpper);
+
+        // Finally, for ambiguity-related errors, we actually want to look
+        // for a parameter that is the source of the inference type left
+        // over in this predicate.
+        if let traits::FulfillmentErrorCode::CodeAmbiguity = error.code {
+            fallback_param_to_point_at = None;
+            self_param_to_point_at = None;
+            param_to_point_at =
+                self.find_ambiguous_parameter_in(def_id, error.root_obligation.predicate);
+        }
+
+        if self.closure_span_overlaps_error(error, expr.span) {
+            return false;
+        }
+
+        match &expr.kind {
+            hir::ExprKind::Path(qpath) => {
+                if let hir::Node::Expr(hir::Expr {
+                    kind: hir::ExprKind::Call(callee, args),
+                    hir_id: call_hir_id,
+                    span: call_span,
+                    ..
+                }) = hir.get_parent(expr.hir_id)
+                    && callee.hir_id == expr.hir_id
+                {
+                    if self.closure_span_overlaps_error(error, *call_span) {
+                        return false;
+                    }
+
+                    for param in
+                        [param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
+                        .into_iter()
+                        .flatten()
+                    {
+                        if self.blame_specific_arg_if_possible(
+                                error,
+                                def_id,
+                                param,
+                                *call_hir_id,
+                                callee.span,
+                                None,
+                                args,
+                            )
+                        {
+                            return true;
+                        }
+                    }
+                }
+                // Notably, we only point to params that are local to the
+                // item we're checking, since those are the ones we are able
+                // to look in the final `hir::PathSegment` for. Everything else
+                // would require a deeper search into the `qpath` than I think
+                // is worthwhile.
+                if let Some(param_to_point_at) = param_to_point_at
+                    && self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)
+                {
+                    return true;
+                }
+            }
+            hir::ExprKind::MethodCall(segment, receiver, args, ..) => {
+                for param in [param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
+                    .into_iter()
+                    .flatten()
+                {
+                    if self.blame_specific_arg_if_possible(
+                        error,
+                        def_id,
+                        param,
+                        hir_id,
+                        segment.ident.span,
+                        Some(receiver),
+                        args,
+                    ) {
+                        return true;
+                    }
+                }
+                if let Some(param_to_point_at) = param_to_point_at
+                    && self.point_at_generic_if_possible(error, def_id, param_to_point_at, segment)
+                {
+                    return true;
+                }
+            }
+            hir::ExprKind::Struct(qpath, fields, ..) => {
+                if let Res::Def(
+                    hir::def::DefKind::Struct | hir::def::DefKind::Variant,
+                    variant_def_id,
+                ) = self.typeck_results.borrow().qpath_res(qpath, hir_id)
+                {
+                    for param in
+                        [param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
+                    {
+                        if let Some(param) = param {
+                            let refined_expr = self.point_at_field_if_possible(
+                                def_id,
+                                param,
+                                variant_def_id,
+                                fields,
+                            );
+
+                            match refined_expr {
+                                None => {}
+                                Some((refined_expr, _)) => {
+                                    error.obligation.cause.span = refined_expr
+                                        .span
+                                        .find_ancestor_in_same_ctxt(error.obligation.cause.span)
+                                        .unwrap_or(refined_expr.span);
+                                    return true;
+                                }
+                            }
+                        }
+                    }
+                }
+                if let Some(param_to_point_at) = param_to_point_at
+                    && self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)
+                {
+                    return true;
+                }
+            }
+            _ => {}
+        }
+
+        false
+    }
+
+    fn point_at_path_if_possible(
+        &self,
+        error: &mut traits::FulfillmentError<'tcx>,
+        def_id: DefId,
+        param: ty::GenericArg<'tcx>,
+        qpath: &hir::QPath<'tcx>,
+    ) -> bool {
+        match qpath {
+            hir::QPath::Resolved(_, path) => {
+                if let Some(segment) = path.segments.last()
+                    && self.point_at_generic_if_possible(error, def_id, param, segment)
+                {
+                    return true;
+                }
+            }
+            hir::QPath::TypeRelative(_, segment) => {
+                if self.point_at_generic_if_possible(error, def_id, param, segment) {
+                    return true;
+                }
+            }
+            _ => {}
+        }
+
+        false
+    }
+
+    fn point_at_generic_if_possible(
+        &self,
+        error: &mut traits::FulfillmentError<'tcx>,
+        def_id: DefId,
+        param_to_point_at: ty::GenericArg<'tcx>,
+        segment: &hir::PathSegment<'tcx>,
+    ) -> bool {
+        let own_substs = self
+            .tcx
+            .generics_of(def_id)
+            .own_substs(ty::InternalSubsts::identity_for_item(self.tcx, def_id));
+        let Some((index, _)) = own_substs
+            .iter()
+            .filter(|arg| matches!(arg.unpack(), ty::GenericArgKind::Type(_)))
+            .enumerate()
+            .find(|(_, arg)| **arg == param_to_point_at) else { return false };
+        let Some(arg) = segment
+            .args()
+            .args
+            .iter()
+            .filter(|arg| matches!(arg, hir::GenericArg::Type(_)))
+            .nth(index) else { return false; };
+        error.obligation.cause.span = arg
+            .span()
+            .find_ancestor_in_same_ctxt(error.obligation.cause.span)
+            .unwrap_or(arg.span());
+        true
+    }
+
+    fn find_ambiguous_parameter_in<T: TypeVisitable<'tcx>>(
+        &self,
+        item_def_id: DefId,
+        t: T,
+    ) -> Option<ty::GenericArg<'tcx>> {
+        struct FindAmbiguousParameter<'a, 'tcx>(&'a FnCtxt<'a, 'tcx>, DefId);
+        impl<'tcx> TypeVisitor<'tcx> for FindAmbiguousParameter<'_, 'tcx> {
+            type BreakTy = ty::GenericArg<'tcx>;
+            fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
+                if let Some(origin) = self.0.type_var_origin(ty)
+                    && let rustc_infer::infer::type_variable::TypeVariableOriginKind::TypeParameterDefinition(_, Some(def_id)) =
+                        origin.kind
+                    && let generics = self.0.tcx.generics_of(self.1)
+                    && let Some(index) = generics.param_def_id_to_index(self.0.tcx, def_id)
+                    && let Some(subst) = ty::InternalSubsts::identity_for_item(self.0.tcx, self.1)
+                        .get(index as usize)
+                {
+                    ControlFlow::Break(*subst)
+                } else {
+                    ty.super_visit_with(self)
+                }
+            }
+        }
+        t.visit_with(&mut FindAmbiguousParameter(self, item_def_id)).break_value()
+    }
+
+    fn closure_span_overlaps_error(
+        &self,
+        error: &traits::FulfillmentError<'tcx>,
+        span: Span,
+    ) -> bool {
+        if let traits::FulfillmentErrorCode::CodeSelectionError(
+            traits::SelectionError::OutputTypeParameterMismatch(_, expected, _),
+        ) = error.code
+            && let ty::Closure(def_id, _) | ty::Generator(def_id, ..) = expected.skip_binder().self_ty().kind()
+            && span.overlaps(self.tcx.def_span(*def_id))
+        {
+            true
+        } else {
+            false
+        }
+    }
+
+    fn point_at_field_if_possible(
+        &self,
+        def_id: DefId,
+        param_to_point_at: ty::GenericArg<'tcx>,
+        variant_def_id: DefId,
+        expr_fields: &[hir::ExprField<'tcx>],
+    ) -> Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)> {
+        let def = self.tcx.adt_def(def_id);
+
+        let identity_substs = ty::InternalSubsts::identity_for_item(self.tcx, def_id);
+        let fields_referencing_param: Vec<_> = def
+            .variant_with_id(variant_def_id)
+            .fields
+            .iter()
+            .filter(|field| {
+                let field_ty = field.ty(self.tcx, identity_substs);
+                Self::find_param_in_ty(field_ty.into(), param_to_point_at)
+            })
+            .collect();
+
+        if let [field] = fields_referencing_param.as_slice() {
+            for expr_field in expr_fields {
+                // Look for the ExprField that matches the field, using the
+                // same rules that check_expr_struct uses for macro hygiene.
+                if self.tcx.adjust_ident(expr_field.ident, variant_def_id) == field.ident(self.tcx)
+                {
+                    return Some((expr_field.expr, self.tcx.type_of(field.did)));
+                }
+            }
+        }
+
+        None
+    }
+
+    /// - `blame_specific_*` means that the function will recursively traverse the expression,
+    /// looking for the most-specific-possible span to blame.
+    ///
+    /// - `point_at_*` means that the function will only go "one level", pointing at the specific
+    /// expression mentioned.
+    ///
+    /// `blame_specific_arg_if_possible` will find the most-specific expression anywhere inside
+    /// the provided function call expression, and mark it as responsible for the fullfillment
+    /// error.
+    fn blame_specific_arg_if_possible(
+        &self,
+        error: &mut traits::FulfillmentError<'tcx>,
+        def_id: DefId,
+        param_to_point_at: ty::GenericArg<'tcx>,
+        call_hir_id: hir::HirId,
+        callee_span: Span,
+        receiver: Option<&'tcx hir::Expr<'tcx>>,
+        args: &'tcx [hir::Expr<'tcx>],
+    ) -> bool {
+        let ty = self.tcx.type_of(def_id);
+        if !ty.is_fn() {
+            return false;
+        }
+        let sig = ty.fn_sig(self.tcx).skip_binder();
+        let args_referencing_param: Vec<_> = sig
+            .inputs()
+            .iter()
+            .enumerate()
+            .filter(|(_, ty)| Self::find_param_in_ty((**ty).into(), param_to_point_at))
+            .collect();
+        // If there's one field that references the given generic, great!
+        if let [(idx, _)] = args_referencing_param.as_slice()
+            && let Some(arg) = receiver
+                .map_or(args.get(*idx), |rcvr| if *idx == 0 { Some(rcvr) } else { args.get(*idx - 1) }) {
+
+            error.obligation.cause.span = arg.span.find_ancestor_in_same_ctxt(error.obligation.cause.span).unwrap_or(arg.span);
+
+            if let hir::Node::Expr(arg_expr) = self.tcx.hir().get(arg.hir_id) {
+                // This is more specific than pointing at the entire argument.
+                self.blame_specific_expr_if_possible(error, arg_expr)
+            }
+
+            error.obligation.cause.map_code(|parent_code| {
+                ObligationCauseCode::FunctionArgumentObligation {
+                    arg_hir_id: arg.hir_id,
+                    call_hir_id,
+                    parent_code,
+                }
+            });
+            return true;
+        } else if args_referencing_param.len() > 0 {
+            // If more than one argument applies, then point to the callee span at least...
+            // We have chance to fix this up further in `point_at_generics_if_possible`
+            error.obligation.cause.span = callee_span;
+        }
+
+        false
+    }
+
     /**
      * Recursively searches for the most-specific blamable expression.
      * For example, if you have a chain of constraints like: