]> git.lizzy.rs Git - rust.git/commitdiff
syntax: Update #[derive(...)] to work with phantom and associated types
authorErick Tryzelaar <erick.tryzelaar@gmail.com>
Tue, 24 Mar 2015 21:43:26 +0000 (14:43 -0700)
committerErick Tryzelaar <erick.tryzelaar@gmail.com>
Tue, 24 Mar 2015 21:43:26 +0000 (14:43 -0700)
Closes #7671, #19839

src/libsyntax/ext/deriving/generic/mod.rs
src/test/compile-fail/derive-assoc-type-not-impl.rs [new file with mode: 0644]
src/test/run-pass/deriving-associated-types.rs [new file with mode: 0644]

index 58b6d96607df777454cb9b841e60fddc4c77848e..0c5e4d67642acc08a82fdbd838ef4dbf21de7dec 100644 (file)
@@ -332,6 +332,46 @@ pub fn combine_substructure<'a>(f: CombineSubstructureFunc<'a>)
     RefCell::new(f)
 }
 
+/// This method helps to extract all the type parameters referenced from a
+/// type. For a type parameter `<T>`, it looks for either a `TyPath` that
+/// is not global and starts with `T`, or a `TyQPath`.
+fn find_type_parameters(ty: &ast::Ty, ty_param_names: &[ast::Name]) -> Vec<P<ast::Ty>> {
+    use visit;
+
+    struct Visitor<'a> {
+        ty_param_names: &'a [ast::Name],
+        types: Vec<P<ast::Ty>>,
+    }
+
+    impl<'a> visit::Visitor<'a> for Visitor<'a> {
+        fn visit_ty(&mut self, ty: &'a ast::Ty) {
+            match ty.node {
+                ast::TyPath(_, ref path) if !path.global => {
+                    match path.segments.first() {
+                        Some(segment) => {
+                            if self.ty_param_names.contains(&segment.identifier.name) {
+                                self.types.push(P(ty.clone()));
+                            }
+                        }
+                        None => {}
+                    }
+                }
+                _ => {}
+            }
+
+            visit::walk_ty(self, ty)
+        }
+    }
+
+    let mut visitor = Visitor {
+        ty_param_names: ty_param_names,
+        types: Vec::new(),
+    };
+
+    visit::Visitor::visit_ty(&mut visitor, ty);
+
+    visitor.types
+}
 
 impl<'a> TraitDef<'a> {
     pub fn expand<F>(&self,
@@ -374,18 +414,42 @@ pub fn expand<F>(&self,
         }))
     }
 
-    /// Given that we are deriving a trait `Tr` for a type `T<'a, ...,
-    /// 'z, A, ..., Z>`, creates an impl like:
+    /// Given that we are deriving a trait `DerivedTrait` for a type like:
     ///
     /// ```ignore
-    /// impl<'a, ..., 'z, A:Tr B1 B2, ..., Z: Tr B1 B2> Tr for T<A, ..., Z> { ... }
+    /// struct Struct<'a, ..., 'z, A, B: DeclaredTrait, C, ..., Z> where C: WhereTrait {
+    ///     a: A,
+    ///     b: B::Item,
+    ///     b1: <B as DeclaredTrait>::Item,
+    ///     c1: <C as WhereTrait>::Item,
+    ///     c2: Option<<C as WhereTrait>::Item>,
+    ///     ...
+    /// }
+    /// ```
+    ///
+    /// create an impl like:
+    ///
+    /// ```ignore
+    /// impl<'a, ..., 'z, A, B: DeclaredTrait, C, ...  Z> where
+    ///     C:                       WhereTrait,
+    ///     A: DerivedTrait + B1 + ... + BN,
+    ///     B: DerivedTrait + B1 + ... + BN,
+    ///     C: DerivedTrait + B1 + ... + BN,
+    ///     B::Item:                 DerivedTrait + B1 + ... + BN,
+    ///     <C as WhereTrait>::Item: DerivedTrait + B1 + ... + BN,
+    ///     ...
+    /// {
+    ///     ...
+    /// }
     /// ```
     ///
-    /// where B1, B2, ... are the bounds given by `bounds_paths`.'
+    /// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
+    /// therefore does not get bound by the derived trait.
     fn create_derived_impl(&self,
                            cx: &mut ExtCtxt,
                            type_ident: Ident,
                            generics: &Generics,
+                           field_tys: Vec<P<ast::Ty>>,
                            methods: Vec<P<ast::ImplItem>>) -> P<ast::Item> {
         let trait_path = self.path.to_path(cx, self.span, type_ident, generics);
 
@@ -466,6 +530,35 @@ fn create_derived_impl(&self,
             }
         }));
 
+        if !ty_params.is_empty() {
+            let ty_param_names: Vec<ast::Name> = ty_params.iter()
+                .map(|ty_param| ty_param.ident.name)
+                .collect();
+
+            for field_ty in field_tys.into_iter() {
+                let tys = find_type_parameters(&*field_ty, &ty_param_names);
+
+                for ty in tys.into_iter() {
+                    let mut bounds: Vec<_> = self.additional_bounds.iter().map(|p| {
+                        cx.typarambound(p.to_path(cx, self.span, type_ident, generics))
+                    }).collect();
+
+                    // require the current trait
+                    bounds.push(cx.typarambound(trait_path.clone()));
+
+                    let predicate = ast::WhereBoundPredicate {
+                        span: self.span,
+                        bound_lifetimes: vec![],
+                        bounded_ty: ty,
+                        bounds: OwnedSlice::from_vec(bounds),
+                    };
+
+                    let predicate = ast::WherePredicate::BoundPredicate(predicate);
+                    where_clause.predicates.push(predicate);
+                }
+            }
+        }
+
         let trait_generics = Generics {
             lifetimes: lifetimes,
             ty_params: OwnedSlice::from_vec(ty_params),
@@ -518,6 +611,10 @@ fn expand_struct_def(&self,
                          struct_def: &StructDef,
                          type_ident: Ident,
                          generics: &Generics) -> P<ast::Item> {
+        let field_tys: Vec<P<ast::Ty>> = struct_def.fields.iter()
+            .map(|field| field.node.ty.clone())
+            .collect();
+
         let methods = self.methods.iter().map(|method_def| {
             let (explicit_self, self_args, nonself_args, tys) =
                 method_def.split_self_nonself_args(
@@ -550,7 +647,7 @@ fn expand_struct_def(&self,
                                      body)
         }).collect();
 
-        self.create_derived_impl(cx, type_ident, generics, methods)
+        self.create_derived_impl(cx, type_ident, generics, field_tys, methods)
     }
 
     fn expand_enum_def(&self,
@@ -558,6 +655,21 @@ fn expand_enum_def(&self,
                        enum_def: &EnumDef,
                        type_ident: Ident,
                        generics: &Generics) -> P<ast::Item> {
+        let mut field_tys = Vec::new();
+
+        for variant in enum_def.variants.iter() {
+            match variant.node.kind {
+                ast::VariantKind::TupleVariantKind(ref args) => {
+                    field_tys.extend(args.iter()
+                        .map(|arg| arg.ty.clone()));
+                }
+                ast::VariantKind::StructVariantKind(ref args) => {
+                    field_tys.extend(args.fields.iter()
+                        .map(|field| field.node.ty.clone()));
+                }
+            }
+        }
+
         let methods = self.methods.iter().map(|method_def| {
             let (explicit_self, self_args, nonself_args, tys) =
                 method_def.split_self_nonself_args(cx, self,
@@ -590,7 +702,7 @@ fn expand_enum_def(&self,
                                      body)
         }).collect();
 
-        self.create_derived_impl(cx, type_ident, generics, methods)
+        self.create_derived_impl(cx, type_ident, generics, field_tys, methods)
     }
 }
 
diff --git a/src/test/compile-fail/derive-assoc-type-not-impl.rs b/src/test/compile-fail/derive-assoc-type-not-impl.rs
new file mode 100644 (file)
index 0000000..3799f2f
--- /dev/null
@@ -0,0 +1,29 @@
+// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+trait Foo {
+    type X;
+    fn method(&self) {}
+}
+
+#[derive(Clone)]
+struct Bar<T: Foo> {
+    x: T::X,
+}
+
+struct NotClone;
+
+impl Foo for NotClone {
+    type X = i8;
+}
+
+fn main() {
+    Bar::<NotClone> { x: 1 }.clone(); //~ ERROR
+}
diff --git a/src/test/run-pass/deriving-associated-types.rs b/src/test/run-pass/deriving-associated-types.rs
new file mode 100644 (file)
index 0000000..59eb550
--- /dev/null
@@ -0,0 +1,210 @@
+// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![feature(core, debug_builders)]
+
+pub trait DeclaredTrait {
+    type Type;
+}
+
+impl DeclaredTrait for i32 {
+    type Type = i32;
+}
+
+pub trait WhereTrait {
+    type Type;
+}
+
+impl WhereTrait for i32 {
+    type Type = i32;
+}
+
+// Make sure we don't add a bound that just shares a name with an associated
+// type.
+pub mod module {
+    pub type Type = i32;
+}
+
+#[derive(PartialEq, Debug)]
+struct PrivateStruct<T>(T);
+
+#[derive(PartialEq, Debug)]
+struct TupleStruct<A, B: DeclaredTrait, C>(
+    module::Type,
+    Option<module::Type>,
+    A,
+    PrivateStruct<A>,
+    B,
+    B::Type,
+    Option<B::Type>,
+    <B as DeclaredTrait>::Type,
+    Option<<B as DeclaredTrait>::Type>,
+    C,
+    C::Type,
+    Option<C::Type>,
+    <C as WhereTrait>::Type,
+    Option<<C as WhereTrait>::Type>,
+    <i32 as DeclaredTrait>::Type,
+) where C: WhereTrait;
+
+#[derive(PartialEq, Debug)]
+pub struct Struct<A, B: DeclaredTrait, C> where C: WhereTrait {
+    m1: module::Type,
+    m2: Option<module::Type>,
+    a1: A,
+    a2: PrivateStruct<A>,
+    b: B,
+    b1: B::Type,
+    b2: Option<B::Type>,
+    b3: <B as DeclaredTrait>::Type,
+    b4: Option<<B as DeclaredTrait>::Type>,
+    c: C,
+    c1: C::Type,
+    c2: Option<C::Type>,
+    c3: <C as WhereTrait>::Type,
+    c4: Option<<C as WhereTrait>::Type>,
+    d: <i32 as DeclaredTrait>::Type,
+}
+
+#[derive(PartialEq, Debug)]
+enum Enum<A, B: DeclaredTrait, C> where C: WhereTrait {
+    Unit,
+    Seq(
+        module::Type,
+        Option<module::Type>,
+        A,
+        PrivateStruct<A>,
+        B,
+        B::Type,
+        Option<B::Type>,
+        <B as DeclaredTrait>::Type,
+        Option<<B as DeclaredTrait>::Type>,
+        C,
+        C::Type,
+        Option<C::Type>,
+        <C as WhereTrait>::Type,
+        Option<<C as WhereTrait>::Type>,
+        <i32 as DeclaredTrait>::Type,
+    ),
+    Map {
+        m1: module::Type,
+        m2: Option<module::Type>,
+        a1: A,
+        a2: PrivateStruct<A>,
+        b: B,
+        b1: B::Type,
+        b2: Option<B::Type>,
+        b3: <B as DeclaredTrait>::Type,
+        b4: Option<<B as DeclaredTrait>::Type>,
+        c: C,
+        c1: C::Type,
+        c2: Option<C::Type>,
+        c3: <C as WhereTrait>::Type,
+        c4: Option<<C as WhereTrait>::Type>,
+        d: <i32 as DeclaredTrait>::Type,
+    },
+}
+
+fn main() {
+    let e: TupleStruct<
+        i32,
+        i32,
+        i32,
+    > = TupleStruct(
+        0,
+        None,
+        0,
+        PrivateStruct(0),
+        0,
+        0,
+        None,
+        0,
+        None,
+        0,
+        0,
+        None,
+        0,
+        None,
+        0,
+    );
+    assert_eq!(e, e);
+
+    let e: Struct<
+        i32,
+        i32,
+        i32,
+    > = Struct {
+        m1: 0,
+        m2: None,
+        a1: 0,
+        a2: PrivateStruct(0),
+        b: 0,
+        b1: 0,
+        b2: None,
+        b3: 0,
+        b4: None,
+        c: 0,
+        c1: 0,
+        c2: None,
+        c3: 0,
+        c4: None,
+        d: 0,
+    };
+    assert_eq!(e, e);
+
+    let e = Enum::Unit::<i32, i32, i32>;
+    assert_eq!(e, e);
+
+    let e: Enum<
+        i32,
+        i32,
+        i32,
+    > = Enum::Seq(
+        0,
+        None,
+        0,
+        PrivateStruct(0),
+        0,
+        0,
+        None,
+        0,
+        None,
+        0,
+        0,
+        None,
+        0,
+        None,
+        0,
+    );
+    assert_eq!(e, e);
+
+    let e: Enum<
+        i32,
+        i32,
+        i32,
+    > = Enum::Map {
+        m1: 0,
+        m2: None,
+        a1: 0,
+        a2: PrivateStruct(0),
+        b: 0,
+        b1: 0,
+        b2: None,
+        b3: 0,
+        b4: None,
+        c: 0,
+        c1: 0,
+        c2: None,
+        c3: 0,
+        c4: None,
+        d: 0,
+    };
+    assert_eq!(e, e);
+}