]> git.lizzy.rs Git - enumset.git/blobdiff - enumset_derive/src/lib.rs
Fix compilation with serde flag and other traits that define serialize/deserialize.
[enumset.git] / enumset_derive / src / lib.rs
index e977b254bcfac4fc6853c5044847e2ef9e69c1e7..b4dcad193d5816724a301443d63a5c38ef003ee4 100644 (file)
@@ -1,14 +1,11 @@
-#![recursion_limit="128"]
+#![recursion_limit="256"]
 #![cfg_attr(feature = "nightly", feature(proc_macro_diagnostic))]
 
-extern crate syn;
 extern crate proc_macro;
-extern crate proc_macro2;
-extern crate quote;
 
-use self::proc_macro::{TokenStream, TokenTree, Literal};
-
-use proc_macro2::{TokenStream as SynTokenStream};
+use darling::*;
+use proc_macro::TokenStream;
+use proc_macro2::{TokenStream as SynTokenStream, Literal};
 use syn::*;
 use syn::export::Span;
 use syn::spanned::Spanned;
@@ -26,16 +23,20 @@ fn error(_: Span, data: &str) -> TokenStream {
 }
 
 fn enum_set_type_impl(
-    name: &Ident, all_variants: u128, repr: Ident, no_ops: bool, no_derives: bool,
+    name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
 ) -> SynTokenStream {
+    let is_uninhabited = variants.is_empty();
+    let is_zst = variants.len() == 1;
+
     let typed_enumset = quote!(::enumset::EnumSet<#name>);
-    let core = quote!(::enumset::internal::core);
+    let core = quote!(::enumset::internal::core_export);
+    #[cfg(feature = "serde")]
+    let serde = quote!(::enumset::internal::serde);
 
     // proc_macro2 does not support creating u128 literals.
-    let all_variants_tt = TokenTree::Literal(Literal::u128_unsuffixed(all_variants));
-    let all_variants_tt = SynTokenStream::from(TokenStream::from(all_variants_tt));
+    let all_variants = Literal::u128_unsuffixed(all_variants);
 
-    let ops = if no_ops {
+    let ops = if attrs.no_ops {
         quote! {}
     } else {
         quote! {
@@ -77,45 +78,104 @@ fn enum_set_type_impl(
         }
     };
 
-    let derives = if no_derives {
-        quote! {}
-    } else {
+    #[cfg(feature = "serde")]
+    let serde_ops = if attrs.serialize_as_list {
+        let expecting_str = format!("a list of {}", name);
         quote! {
-            impl #core::cmp::PartialOrd for #name {
-                fn partial_cmp(&self, other: &Self) -> #core::option::Option<#core::cmp::Ordering> {
-                    (*self as u8).partial_cmp(&(*other as u8))
+            fn serialize<S: #serde::Serializer>(
+                set: ::enumset::EnumSet<#name>, ser: S,
+            ) -> #core::result::Result<S::Ok, S::Error> {
+                use #serde::ser::SerializeSeq;
+                let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
+                for bit in set {
+                    seq.serialize_element(&bit)?;
                 }
+                seq.end()
             }
-            impl #core::cmp::Ord for #name {
-                fn cmp(&self, other: &Self) -> #core::cmp::Ordering {
-                    (*self as u8).cmp(&(*other as u8))
+            fn deserialize<'de, D: #serde::Deserializer<'de>>(
+                de: D,
+            ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
+                struct Visitor;
+                impl <'de> #serde::de::Visitor<'de> for Visitor {
+                    type Value = ::enumset::EnumSet<#name>;
+                    fn expecting(
+                        &self, formatter: &mut #core::fmt::Formatter,
+                    ) -> #core::fmt::Result {
+                        write!(formatter, #expecting_str)
+                    }
+                    fn visit_seq<A>(
+                        mut self, mut seq: A,
+                    ) -> #core::result::Result<Self::Value, A::Error> where
+                        A: #serde::de::SeqAccess<'de>
+                    {
+                        let mut accum = ::enumset::EnumSet::<#name>::new();
+                        while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
+                            accum |= val;
+                        }
+                        #core::prelude::v1::Ok(accum)
+                    }
                 }
+                de.deserialize_seq(Visitor)
             }
-            impl #core::cmp::PartialEq for #name {
-                fn eq(&self, other: &Self) -> bool {
-                    (*self as u8) == (*other as u8)
+        }
+    } else {
+        let serialize_repr = attrs.serialize_repr.as_ref()
+            .map(|x| Ident::new(&x, Span::call_site()))
+            .unwrap_or(repr.clone());
+        let check_unknown = if attrs.serialize_deny_unknown {
+            quote! {
+                if value & !#all_variants != 0 {
+                    use #serde::de::Error;
+                    return #core::prelude::v1::Err(
+                        D::Error::custom("enumset contains unknown bits")
+                    )
                 }
             }
-            impl #core::cmp::Eq for #name { }
-            impl #core::hash::Hash for #name {
-                fn hash<H: #core::hash::Hasher>(&self, state: &mut H) {
-                    state.write_u8(*self as u8)
-                }
+        } else {
+            quote! { }
+        };
+        quote! {
+            fn serialize<S: #serde::Serializer>(
+                set: ::enumset::EnumSet<#name>, ser: S,
+            ) -> #core::result::Result<S::Ok, S::Error> {
+                #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
             }
-            impl #core::clone::Clone for #name {
-                fn clone(&self) -> Self {
-                    *self
-                }
+            fn deserialize<'de, D: #serde::Deserializer<'de>>(
+                de: D,
+            ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
+                let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
+                #check_unknown
+                #core::prelude::v1::Ok(::enumset::EnumSet {
+                    __enumset_underlying: (value & #all_variants) as #repr,
+                })
             }
-            impl #core::marker::Copy for #name { }
         }
     };
 
-    quote! {
-        unsafe impl ::enumset::EnumSetType for #name {
-            type Repr = #repr;
-            const ALL_BITS: Self::Repr = #all_variants_tt;
+    #[cfg(not(feature = "serde"))]
+    let serde_ops = quote! { };
 
+    let into_impl = if is_uninhabited {
+        quote! {
+            fn enum_into_u8(self) -> u8 {
+                panic!(concat!(stringify!(#name), " is uninhabited."))
+            }
+            unsafe fn enum_from_u8(val: u8) -> Self {
+                panic!(concat!(stringify!(#name), " is uninhabited."))
+            }
+        }
+    } else if is_zst {
+        let variant = &variants[0];
+        quote! {
+            fn enum_into_u8(self) -> u8 {
+                self as u8
+            }
+            unsafe fn enum_from_u8(val: u8) -> Self {
+                #name::#variant
+            }
+        }
+    } else {
+        quote! {
             fn enum_into_u8(self) -> u8 {
                 self as u8
             }
@@ -123,16 +183,55 @@ fn enum_set_type_impl(
                 #core::mem::transmute(val)
             }
         }
+    };
+
+    let eq_impl = if is_uninhabited {
+        quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
+    } else {
+        quote!((*self as u8) == (*other as u8))
+    };
+
+    quote! {
+        unsafe impl ::enumset::internal::EnumSetTypePrivate for #name {
+            type Repr = #repr;
+            const ALL_BITS: Self::Repr = #all_variants;
+            #into_impl
+            #serde_ops
+        }
+
+        unsafe impl ::enumset::EnumSetType for #name { }
+
+        impl #core::cmp::PartialEq for #name {
+            fn eq(&self, other: &Self) -> bool {
+                #eq_impl
+            }
+        }
+        impl #core::cmp::Eq for #name { }
+        impl #core::clone::Clone for #name {
+            fn clone(&self) -> Self {
+                *self
+            }
+        }
+        impl #core::marker::Copy for #name { }
 
         #ops
-        #derives
     }
 }
 
-#[proc_macro_derive(EnumSetType, attributes(enumset_no_ops, enumset_no_derives))]
+#[derive(FromDeriveInput, Default)]
+#[darling(attributes(enumset), default)]
+struct EnumsetAttrs {
+    no_ops: bool,
+    serialize_as_list: bool,
+    serialize_deny_unknown: bool,
+    #[darling(default)]
+    serialize_repr: Option<String>,
+}
+
+#[proc_macro_derive(EnumSetType, attributes(enumset))]
 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
     let input: DeriveInput = parse_macro_input!(input);
-    if let Data::Enum(data) = input.data {
+    if let Data::Enum(data) = &input.data {
         if !input.generics.params.is_empty() {
             error(input.generics.span(),
                   "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.")
@@ -141,12 +240,16 @@ pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
             let mut max_variant = 0;
             let mut current_variant = 0;
             let mut has_manual_discriminant = false;
+            let mut variants = Vec::new();
 
             for variant in &data.variants {
                 if let Fields::Unit = variant.fields {
                     if let Some((_, expr)) = &variant.discriminant {
                         if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
-                            current_variant = i.value();
+                            current_variant = match i.base10_parse() {
+                                Ok(val) => val,
+                                Err(_) => return error(expr.span(), "Error parsing discriminant."),
+                            };
                             has_manual_discriminant = true;
                         } else {
                             return error(variant.span(), "Unrecognized discriminant for variant.")
@@ -171,6 +274,7 @@ pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
                         max_variant = current_variant
                     }
 
+                    variants.push(variant.ident.clone());
                     current_variant += 1;
                 } else {
                     return error(variant.span(),
@@ -178,44 +282,47 @@ pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
                 }
             }
 
-            let repr = Ident::new(if max_variant <= 8 {
+            let repr = Ident::new(if max_variant <= 7 {
                 "u8"
-            } else if max_variant <= 16 {
+            } else if max_variant <= 15 {
                 "u16"
-            } else if max_variant <= 32 {
+            } else if max_variant <= 31 {
                 "u32"
-            } else if max_variant <= 64 {
+            } else if max_variant <= 63 {
                 "u64"
-            } else if max_variant <= 128 {
+            } else if max_variant <= 127 {
                 "u128"
             } else {
-                panic!("max_variant > 128?")
+                panic!("max_variant > 127?")
             }, Span::call_site());
 
-            let mut no_ops = false;
-            let mut no_derives = false;
-
-            for attr in &input.attrs {
-                let span = attr.span();
-                let Attribute { tts, path: Path { segments, ..}, .. } = attr;
+            let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
+                Ok(attrs) => attrs,
+                Err(e) => return e.write_errors().into(),
+            };
 
-                if segments.len() == 1 && segments[0].ident.to_string() == "enumset_no_ops" {
-                    no_ops = true;
-                    if !tts.is_empty() {
-                        return error(span, "`#[enumset_no_ops]` takes no arguments.")
-                    }
+            match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
+                Some("u8") => if max_variant > 7 {
+                    return error(input.span(), "Too many variants for u8 serialization repr.")
                 }
-                if segments.len() == 1 && segments[0].ident.to_string() == "enumset_no_derives" {
-                    no_derives = true;
-                    if !tts.is_empty() {
-                        return error(span, "`#[enumset_no_derives]` takes no arguments.")
-                    }
+                Some("u16") => if max_variant > 15 {
+                    return error(input.span(), "Too many variants for u16 serialization repr.")
                 }
-            }
+                Some("u32") => if max_variant > 31 {
+                    return error(input.span(), "Too many variants for u32 serialization repr.")
+                }
+                Some("u64") => if max_variant > 63 {
+                    return error(input.span(), "Too many variants for u64 serialization repr.")
+                }
+                Some("u128") => if max_variant > 127 {
+                    return error(input.span(), "Too many variants for u128 serialization repr.")
+                }
+                None => { }
+                Some(x) => return error(input.span(),
+                                        &format!("{} is not a valid serialization repr.", x)),
+            };
 
-            enum_set_type_impl(
-                &input.ident, all_variants, repr, no_ops, no_derives,
-            ).into()
+            enum_set_type_impl(&input.ident, all_variants, repr, attrs, variants).into()
         }
     } else {
         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")