]> git.lizzy.rs Git - rust.git/commitdiff
For associated type shorthand (T::Item), use the substs from the where clause
authorFlorian Diebold <flodiebold@gmail.com>
Sun, 26 Apr 2020 14:56:25 +0000 (16:56 +0200)
committerFlorian Diebold <flodiebold@gmail.com>
Sun, 26 Apr 2020 14:58:08 +0000 (16:58 +0200)
So e.g. if we have `fn foo<T: SomeTrait<u32>>() -> T::Item`, we want to lower
that to `<T as SomeTrait<u32>>::Item` and not `<T as SomeTrait<_>>::Item`.

crates/ra_hir_ty/src/lib.rs
crates/ra_hir_ty/src/lower.rs
crates/ra_hir_ty/src/tests/traits.rs
crates/ra_hir_ty/src/utils.rs

index 279c06d65d197e2ed8bc97e4a50e691fb850964c..a8ef32ec59730b08f2353b87bd13e7e5755bff52 100644 (file)
@@ -487,6 +487,18 @@ impl<T> Binders<T> {
     pub fn new(num_binders: usize, value: T) -> Self {
         Self { num_binders, value }
     }
+
+    pub fn as_ref(&self) -> Binders<&T> {
+        Binders { num_binders: self.num_binders, value: &self.value }
+    }
+
+    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Binders<U> {
+        Binders { num_binders: self.num_binders, value: f(self.value) }
+    }
+
+    pub fn filter_map<U>(self, f: impl FnOnce(T) -> Option<U>) -> Option<Binders<U>> {
+        Some(Binders { num_binders: self.num_binders, value: f(self.value)? })
+    }
 }
 
 impl<T: Clone> Binders<&T> {
index b572142966165c96421d76343736aebcb6c5ddf6..a6f893037f5dde27454ebfb3f45025921f3d4ff6 100644 (file)
     db::HirDatabase,
     primitive::{FloatTy, IntTy},
     utils::{
-        all_super_traits, associated_type_by_name_including_super_traits, generics, make_mut_slice,
-        variant_data,
+        all_super_trait_refs, associated_type_by_name_including_super_traits, generics,
+        make_mut_slice, variant_data,
     },
     Binders, BoundVar, DebruijnIndex, FnSig, GenericPredicate, PolyFnSig, ProjectionPredicate,
-    ProjectionTy, Substs, TraitEnvironment, TraitRef, Ty, TypeCtor,
+    ProjectionTy, Substs, TraitEnvironment, TraitRef, Ty, TypeCtor, TypeWalk,
 };
 
 #[derive(Debug)]
@@ -256,7 +256,7 @@ pub(crate) fn from_type_relative_path(
         if remaining_segments.len() == 1 {
             // resolve unselected assoc types
             let segment = remaining_segments.first().unwrap();
-            (Ty::select_associated_type(ctx, ty, res, segment), None)
+            (Ty::select_associated_type(ctx, res, segment), None)
         } else if remaining_segments.len() > 1 {
             // FIXME report error (ambiguous associated type)
             (Ty::Unknown, None)
@@ -380,21 +380,20 @@ pub(crate) fn from_hir_path(ctx: &TyLoweringContext<'_>, path: &Path) -> (Ty, Op
 
     fn select_associated_type(
         ctx: &TyLoweringContext<'_>,
-        self_ty: Ty,
         res: Option<TypeNs>,
         segment: PathSegment<'_>,
     ) -> Ty {
         let traits_from_env: Vec<_> = match res {
             Some(TypeNs::SelfType(impl_id)) => match ctx.db.impl_trait(impl_id) {
                 None => return Ty::Unknown,
-                Some(trait_ref) => vec![trait_ref.value.trait_],
+                Some(trait_ref) => vec![trait_ref.value],
             },
             Some(TypeNs::GenericParam(param_id)) => {
                 let predicates = ctx.db.generic_predicates_for_param(param_id);
                 let mut traits_: Vec<_> = predicates
                     .iter()
                     .filter_map(|pred| match &pred.value {
-                        GenericPredicate::Implemented(tr) => Some(tr.trait_),
+                        GenericPredicate::Implemented(tr) => Some(tr.clone()),
                         _ => None,
                     })
                     .collect();
@@ -404,20 +403,37 @@ fn select_associated_type(
                     if generics.params.types[param_id.local_id].provenance
                         == TypeParamProvenance::TraitSelf
                     {
-                        traits_.push(trait_id);
+                        let trait_ref = TraitRef {
+                            trait_: trait_id,
+                            substs: Substs::bound_vars(&generics, DebruijnIndex::INNERMOST),
+                        };
+                        traits_.push(trait_ref);
                     }
                 }
                 traits_
             }
             _ => return Ty::Unknown,
         };
-        let traits = traits_from_env.into_iter().flat_map(|t| all_super_traits(ctx.db.upcast(), t));
+        let traits = traits_from_env.into_iter().flat_map(|t| all_super_trait_refs(ctx.db, t));
         for t in traits {
-            if let Some(associated_ty) = ctx.db.trait_data(t).associated_type_by_name(&segment.name)
+            if let Some(associated_ty) =
+                ctx.db.trait_data(t.trait_).associated_type_by_name(&segment.name)
             {
-                let substs =
-                    Substs::build_for_def(ctx.db, t).push(self_ty).fill_with_unknown().build();
-                // FIXME handle type parameters on the segment
+                let substs = match ctx.type_param_mode {
+                    TypeParamLoweringMode::Placeholder => {
+                        // if we're lowering to placeholders, we have to put
+                        // them in now
+                        let s = Substs::type_params(
+                            ctx.db,
+                            ctx.resolver
+                                .generic_def()
+                                .expect("there should be generics if there's a generic param"),
+                        );
+                        t.substs.subst_bound_vars(&s)
+                    }
+                    TypeParamLoweringMode::Variable => t.substs,
+                };
+                // FIXME handle (forbid) type parameters on the segment
                 return Ty::Projection(ProjectionTy { associated_ty, parameters: substs });
             }
         }
index f51cdd4964b56e1b48a581fdc6f12e79dd9df48b..e555c879a0454922c10998735f8013d728adc354 100644 (file)
@@ -1897,6 +1897,36 @@ fn test() {
     assert_eq!(t, "u32");
 }
 
+#[test]
+fn unselected_projection_chalk_fold() {
+    let t = type_at(
+        r#"
+//- /main.rs
+trait Interner {}
+trait Fold<I: Interner, TI = I> {
+    type Result;
+}
+
+struct Ty<I: Interner> {}
+impl<I: Interner, TI: Interner> Fold<I, TI> for Ty<I> {
+    type Result = Ty<TI>;
+}
+
+fn fold<I: Interner, T>(interner: &I, t: T) -> T::Result
+where
+    T: Fold<I, I>,
+{
+    loop {}
+}
+
+fn foo<I: Interner>(interner: &I, t: Ty<I>) {
+    fold(interner, t)<|>;
+}
+"#,
+    );
+    assert_eq!(t, "Ty<I>");
+}
+
 #[test]
 fn trait_impl_self_ty() {
     let t = type_at(
index 1e5022fa4c73b5738040ce292e602c62e4e7d24b..f98350bf92187369124a8090111423d78437cfcc 100644 (file)
@@ -14,6 +14,8 @@
 };
 use hir_expand::name::{name, Name};
 
+use crate::{db::HirDatabase, GenericPredicate, TraitRef};
+
 fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId> {
     let resolver = trait_.resolver(db);
     // returning the iterator directly doesn't easily work because of
@@ -41,6 +43,28 @@ fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId> {
         .collect()
 }
 
+fn direct_super_trait_refs(db: &dyn HirDatabase, trait_ref: &TraitRef) -> Vec<TraitRef> {
+    // returning the iterator directly doesn't easily work because of
+    // lifetime problems, but since there usually shouldn't be more than a
+    // few direct traits this should be fine (we could even use some kind of
+    // SmallVec if performance is a concern)
+    let generic_params = db.generic_params(trait_ref.trait_.into());
+    let trait_self = match generic_params.find_trait_self_param() {
+        Some(p) => TypeParamId { parent: trait_ref.trait_.into(), local_id: p },
+        None => return Vec::new(),
+    };
+    db.generic_predicates_for_param(trait_self)
+        .iter()
+        .filter_map(|pred| {
+            pred.as_ref().filter_map(|pred| match pred {
+                GenericPredicate::Implemented(tr) => Some(tr.clone()),
+                _ => None,
+            })
+        })
+        .map(|pred| pred.subst(&trait_ref.substs))
+        .collect()
+}
+
 /// Returns an iterator over the whole super trait hierarchy (including the
 /// trait itself).
 pub(super) fn all_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<TraitId> {
@@ -62,6 +86,30 @@ pub(super) fn all_super_traits(db: &dyn DefDatabase, trait_: TraitId) -> Vec<Tra
     result
 }
 
+/// Given a trait ref (`Self: Trait`), builds all the implied trait refs for
+/// super traits. The original trait ref will be included. So the difference to
+/// `all_super_traits` is that we keep track of type parameters; for example if
+/// we have `Self: Trait<u32, i32>` and `Trait<T, U>: OtherTrait<U>` we'll get
+/// `Self: OtherTrait<i32>`.
+pub(super) fn all_super_trait_refs(db: &dyn HirDatabase, trait_ref: TraitRef) -> Vec<TraitRef> {
+    // we need to take care a bit here to avoid infinite loops in case of cycles
+    // (i.e. if we have `trait A: B; trait B: A;`)
+    let mut result = vec![trait_ref];
+    let mut i = 0;
+    while i < result.len() {
+        let t = &result[i];
+        // yeah this is quadratic, but trait hierarchies should be flat
+        // enough that this doesn't matter
+        for tt in direct_super_trait_refs(db, t) {
+            if !result.iter().any(|tr| tr.trait_ == tt.trait_) {
+                result.push(tt);
+            }
+        }
+        i += 1;
+    }
+    result
+}
+
 /// Finds a path from a trait to one of its super traits. Returns an empty
 /// vector if there is no path.
 pub(super) fn find_super_trait_path(