]> git.lizzy.rs Git - rust.git/commitdiff
Deriving: Include bound generic params for extracted type parameters in where clause
authorAudun Halland <audun.halland@pm.me>
Tue, 28 Sep 2021 22:46:29 +0000 (00:46 +0200)
committerAudun Halland <audun.halland@pm.me>
Tue, 28 Sep 2021 22:46:29 +0000 (00:46 +0200)
compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
src/test/ui/deriving/issue-89188-gat-hrtb.rs [new file with mode: 0644]

index 59f933d422d01951460eb5cdd7d92e01066dd31e..cd013107a0151b6c95cabc580d712669032e04ce 100644 (file)
@@ -332,20 +332,27 @@ pub fn combine_substructure(
     RefCell::new(f)
 }
 
+struct TypeParameter {
+    bound_generic_params: Vec<ast::GenericParam>,
+    ty: P<ast::Ty>,
+}
+
 /// This method helps to extract all the type parameters referenced from a
 /// type. For a type parameter `<T>`, it looks for either a `TyPath` that
 /// is not global and starts with `T`, or a `TyQPath`.
+/// Also include bound generic params from the input type.
 fn find_type_parameters(
     ty: &ast::Ty,
     ty_param_names: &[Symbol],
     cx: &ExtCtxt<'_>,
-) -> Vec<P<ast::Ty>> {
+) -> Vec<TypeParameter> {
     use rustc_ast::visit;
 
     struct Visitor<'a, 'b> {
         cx: &'a ExtCtxt<'b>,
         ty_param_names: &'a [Symbol],
-        types: Vec<P<ast::Ty>>,
+        bound_generic_params_stack: Vec<ast::GenericParam>,
+        type_params: Vec<TypeParameter>,
     }
 
     impl<'a, 'b> visit::Visitor<'a> for Visitor<'a, 'b> {
@@ -353,7 +360,10 @@ fn visit_ty(&mut self, ty: &'a ast::Ty) {
             if let ast::TyKind::Path(_, ref path) = ty.kind {
                 if let Some(segment) = path.segments.first() {
                     if self.ty_param_names.contains(&segment.ident.name) {
-                        self.types.push(P(ty.clone()));
+                        self.type_params.push(TypeParameter {
+                            bound_generic_params: self.bound_generic_params_stack.clone(),
+                            ty: P(ty.clone()),
+                        });
                     }
                 }
             }
@@ -361,15 +371,35 @@ fn visit_ty(&mut self, ty: &'a ast::Ty) {
             visit::walk_ty(self, ty)
         }
 
+        // Place bound generic params on a stack, to extract them when a type is encountered.
+        fn visit_poly_trait_ref(
+            &mut self,
+            trait_ref: &'a ast::PolyTraitRef,
+            modifier: &'a ast::TraitBoundModifier,
+        ) {
+            let stack_len = trait_ref.bound_generic_params.len();
+            self.bound_generic_params_stack
+                .extend(trait_ref.bound_generic_params.clone().into_iter());
+
+            visit::walk_poly_trait_ref(self, trait_ref, modifier);
+
+            self.bound_generic_params_stack.truncate(stack_len);
+        }
+
         fn visit_mac_call(&mut self, mac: &ast::MacCall) {
             self.cx.span_err(mac.span(), "`derive` cannot be used on items with type macros");
         }
     }
 
-    let mut visitor = Visitor { cx, ty_param_names, types: Vec::new() };
+    let mut visitor = Visitor {
+        cx,
+        ty_param_names,
+        bound_generic_params_stack: Vec::new(),
+        type_params: Vec::new(),
+    };
     visit::Visitor::visit_ty(&mut visitor, ty);
 
-    visitor.types
+    visitor.type_params
 }
 
 impl<'a> TraitDef<'a> {
@@ -617,11 +647,11 @@ fn create_derived_impl(
                     ty_params.map(|ty_param| ty_param.ident.name).collect();
 
                 for field_ty in field_tys {
-                    let tys = find_type_parameters(&field_ty, &ty_param_names, cx);
+                    let field_ty_params = find_type_parameters(&field_ty, &ty_param_names, cx);
 
-                    for ty in tys {
+                    for field_ty_param in field_ty_params {
                         // if we have already handled this type, skip it
-                        if let ast::TyKind::Path(_, ref p) = ty.kind {
+                        if let ast::TyKind::Path(_, ref p) = field_ty_param.ty.kind {
                             if p.segments.len() == 1
                                 && ty_param_names.contains(&p.segments[0].ident.name)
                             {
@@ -639,8 +669,8 @@ fn create_derived_impl(
 
                         let predicate = ast::WhereBoundPredicate {
                             span: self.span,
-                            bound_generic_params: Vec::new(),
-                            bounded_ty: ty,
+                            bound_generic_params: field_ty_param.bound_generic_params,
+                            bounded_ty: field_ty_param.ty,
                             bounds,
                         };
 
diff --git a/src/test/ui/deriving/issue-89188-gat-hrtb.rs b/src/test/ui/deriving/issue-89188-gat-hrtb.rs
new file mode 100644 (file)
index 0000000..3295491
--- /dev/null
@@ -0,0 +1,14 @@
+// check-pass
+
+#![feature(generic_associated_types)]
+
+trait CallWithShim: Sized {
+    type Shim<'s>
+    where
+        Self: 's;
+}
+
+#[derive(Clone)]
+struct ShimMethod<T: CallWithShim + 'static>(pub &'static dyn for<'s> Fn(&'s mut T::Shim<'s>));
+
+pub fn main() {}