-#![recursion_limit="128"]
+#![recursion_limit="256"]
#![cfg_attr(feature = "nightly", feature(proc_macro_diagnostic))]
-extern crate syn;
extern crate proc_macro;
-extern crate proc_macro2;
-extern crate quote;
-use self::proc_macro::{TokenStream, TokenTree, Literal};
-
-use proc_macro2::{TokenStream as SynTokenStream};
+use darling::*;
+use proc_macro::TokenStream;
+use proc_macro2::{TokenStream as SynTokenStream, Literal};
use syn::*;
use syn::export::Span;
use syn::spanned::Spanned;
}
fn enum_set_type_impl(
- name: &Ident, all_variants: u128, repr: Ident, no_ops: bool, no_derives: bool,
+ name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
) -> SynTokenStream {
+ let is_uninhabited = variants.is_empty();
+ let is_zst = variants.len() == 1;
+
let typed_enumset = quote!(::enumset::EnumSet<#name>);
- let core = quote!(::enumset::internal::core);
+ let core = quote!(::enumset::internal::core_export);
+ #[cfg(feature = "serde")]
+ let serde = quote!(::enumset::internal::serde);
// proc_macro2 does not support creating u128 literals.
- let all_variants_tt = TokenTree::Literal(Literal::u128_unsuffixed(all_variants));
- let all_variants_tt = SynTokenStream::from(TokenStream::from(all_variants_tt));
+ let all_variants = Literal::u128_unsuffixed(all_variants);
- let ops = if no_ops {
+ let ops = if attrs.no_ops {
quote! {}
} else {
quote! {
}
};
- let derives = if no_derives {
- quote! {}
- } else {
+ #[cfg(feature = "serde")]
+ let serde_ops = if attrs.serialize_as_list {
+ let expecting_str = format!("a list of {}", name);
quote! {
- impl #core::cmp::PartialOrd for #name {
- fn partial_cmp(&self, other: &Self) -> #core::option::Option<#core::cmp::Ordering> {
- (*self as u8).partial_cmp(&(*other as u8))
+ fn serialize<S: #serde::Serializer>(
+ set: ::enumset::EnumSet<#name>, ser: S,
+ ) -> #core::result::Result<S::Ok, S::Error> {
+ use #serde::ser::SerializeSeq;
+ let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
+ for bit in set {
+ seq.serialize_element(&bit)?;
}
+ seq.end()
}
- impl #core::cmp::Ord for #name {
- fn cmp(&self, other: &Self) -> #core::cmp::Ordering {
- (*self as u8).cmp(&(*other as u8))
+ fn deserialize<'de, D: #serde::Deserializer<'de>>(
+ de: D,
+ ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
+ struct Visitor;
+ impl <'de> #serde::de::Visitor<'de> for Visitor {
+ type Value = ::enumset::EnumSet<#name>;
+ fn expecting(
+ &self, formatter: &mut #core::fmt::Formatter,
+ ) -> #core::fmt::Result {
+ write!(formatter, #expecting_str)
+ }
+ fn visit_seq<A>(
+ mut self, mut seq: A,
+ ) -> #core::result::Result<Self::Value, A::Error> where
+ A: #serde::de::SeqAccess<'de>
+ {
+ let mut accum = ::enumset::EnumSet::<#name>::new();
+ while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
+ accum |= val;
+ }
+ #core::prelude::v1::Ok(accum)
+ }
}
+ de.deserialize_seq(Visitor)
}
- impl #core::cmp::PartialEq for #name {
- fn eq(&self, other: &Self) -> bool {
- (*self as u8) == (*other as u8)
+ }
+ } else {
+ let serialize_repr = attrs.serialize_repr.as_ref()
+ .map(|x| Ident::new(&x, Span::call_site()))
+ .unwrap_or(repr.clone());
+ let check_unknown = if attrs.serialize_deny_unknown {
+ quote! {
+ if value & !#all_variants != 0 {
+ use #serde::de::Error;
+ return #core::prelude::v1::Err(
+ D::Error::custom("enumset contains unknown bits")
+ )
}
}
- impl #core::cmp::Eq for #name { }
- impl #core::hash::Hash for #name {
- fn hash<H: #core::hash::Hasher>(&self, state: &mut H) {
- state.write_u8(*self as u8)
- }
+ } else {
+ quote! { }
+ };
+ quote! {
+ fn serialize<S: #serde::Serializer>(
+ set: ::enumset::EnumSet<#name>, ser: S,
+ ) -> #core::result::Result<S::Ok, S::Error> {
+ #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
}
- impl #core::clone::Clone for #name {
- fn clone(&self) -> Self {
- *self
- }
+ fn deserialize<'de, D: #serde::Deserializer<'de>>(
+ de: D,
+ ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
+ let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
+ #check_unknown
+ #core::prelude::v1::Ok(::enumset::EnumSet {
+ __enumset_underlying: (value & #all_variants) as #repr,
+ })
}
- impl #core::marker::Copy for #name { }
}
};
- quote! {
- unsafe impl ::enumset::EnumSetType for #name {
- type Repr = #repr;
- const ALL_BITS: Self::Repr = #all_variants_tt;
+ #[cfg(not(feature = "serde"))]
+ let serde_ops = quote! { };
+ let into_impl = if is_uninhabited {
+ quote! {
+ fn enum_into_u8(self) -> u8 {
+ panic!(concat!(stringify!(#name), " is uninhabited."))
+ }
+ unsafe fn enum_from_u8(val: u8) -> Self {
+ panic!(concat!(stringify!(#name), " is uninhabited."))
+ }
+ }
+ } else if is_zst {
+ let variant = &variants[0];
+ quote! {
+ fn enum_into_u8(self) -> u8 {
+ self as u8
+ }
+ unsafe fn enum_from_u8(val: u8) -> Self {
+ #name::#variant
+ }
+ }
+ } else {
+ quote! {
fn enum_into_u8(self) -> u8 {
self as u8
}
#core::mem::transmute(val)
}
}
+ };
+
+ let eq_impl = if is_uninhabited {
+ quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
+ } else {
+ quote!((*self as u8) == (*other as u8))
+ };
+
+ quote! {
+ unsafe impl ::enumset::internal::EnumSetTypePrivate for #name {
+ type Repr = #repr;
+ const ALL_BITS: Self::Repr = #all_variants;
+ #into_impl
+ #serde_ops
+ }
+
+ unsafe impl ::enumset::EnumSetType for #name { }
+
+ impl #core::cmp::PartialEq for #name {
+ fn eq(&self, other: &Self) -> bool {
+ #eq_impl
+ }
+ }
+ impl #core::cmp::Eq for #name { }
+ impl #core::clone::Clone for #name {
+ fn clone(&self) -> Self {
+ *self
+ }
+ }
+ impl #core::marker::Copy for #name { }
#ops
- #derives
}
}
-#[proc_macro_derive(EnumSetType, attributes(enumset_no_ops, enumset_no_derives))]
+#[derive(FromDeriveInput, Default)]
+#[darling(attributes(enumset), default)]
+struct EnumsetAttrs {
+ no_ops: bool,
+ serialize_as_list: bool,
+ serialize_deny_unknown: bool,
+ #[darling(default)]
+ serialize_repr: Option<String>,
+}
+
+#[proc_macro_derive(EnumSetType, attributes(enumset))]
pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
let input: DeriveInput = parse_macro_input!(input);
- if let Data::Enum(data) = input.data {
+ if let Data::Enum(data) = &input.data {
if !input.generics.params.is_empty() {
error(input.generics.span(),
"`#[derive(EnumSetType)]` cannot be used on enums with type parameters.")
let mut max_variant = 0;
let mut current_variant = 0;
let mut has_manual_discriminant = false;
+ let mut variants = Vec::new();
for variant in &data.variants {
if let Fields::Unit = variant.fields {
if let Some((_, expr)) = &variant.discriminant {
if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
- current_variant = i.value();
+ current_variant = match i.base10_parse() {
+ Ok(val) => val,
+ Err(_) => return error(expr.span(), "Error parsing discriminant."),
+ };
has_manual_discriminant = true;
} else {
return error(variant.span(), "Unrecognized discriminant for variant.")
max_variant = current_variant
}
+ variants.push(variant.ident.clone());
current_variant += 1;
} else {
return error(variant.span(),
}
}
- let repr = Ident::new(if max_variant <= 8 {
+ let repr = Ident::new(if max_variant <= 7 {
"u8"
- } else if max_variant <= 16 {
+ } else if max_variant <= 15 {
"u16"
- } else if max_variant <= 32 {
+ } else if max_variant <= 31 {
"u32"
- } else if max_variant <= 64 {
+ } else if max_variant <= 63 {
"u64"
- } else if max_variant <= 128 {
+ } else if max_variant <= 127 {
"u128"
} else {
- panic!("max_variant > 128?")
+ panic!("max_variant > 127?")
}, Span::call_site());
- let mut no_ops = false;
- let mut no_derives = false;
-
- for attr in &input.attrs {
- let span = attr.span();
- let Attribute { tts, path: Path { segments, ..}, .. } = attr;
+ let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
+ Ok(attrs) => attrs,
+ Err(e) => return e.write_errors().into(),
+ };
- if segments.len() == 1 && segments[0].ident.to_string() == "enumset_no_ops" {
- no_ops = true;
- if !tts.is_empty() {
- return error(span, "`#[enumset_no_ops]` takes no arguments.")
- }
+ match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
+ Some("u8") => if max_variant > 7 {
+ return error(input.span(), "Too many variants for u8 serialization repr.")
}
- if segments.len() == 1 && segments[0].ident.to_string() == "enumset_no_derives" {
- no_derives = true;
- if !tts.is_empty() {
- return error(span, "`#[enumset_no_derives]` takes no arguments.")
- }
+ Some("u16") => if max_variant > 15 {
+ return error(input.span(), "Too many variants for u16 serialization repr.")
}
- }
+ Some("u32") => if max_variant > 31 {
+ return error(input.span(), "Too many variants for u32 serialization repr.")
+ }
+ Some("u64") => if max_variant > 63 {
+ return error(input.span(), "Too many variants for u64 serialization repr.")
+ }
+ Some("u128") => if max_variant > 127 {
+ return error(input.span(), "Too many variants for u128 serialization repr.")
+ }
+ None => { }
+ Some(x) => return error(input.span(),
+ &format!("{} is not a valid serialization repr.", x)),
+ };
- enum_set_type_impl(
- &input.ident, all_variants, repr, no_ops, no_derives,
- ).into()
+ enum_set_type_impl(&input.ident, all_variants, repr, attrs, variants).into()
}
} else {
error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")