use darling::*;
use proc_macro::TokenStream;
use proc_macro2::{TokenStream as SynTokenStream, Literal};
+use std::collections::HashSet;
use syn::{*, Result, Error};
use syn::export::Span;
use syn::spanned::Spanned;
use quote::*;
+/// Helper function for emitting compile errors.
fn error<T>(span: Span, message: &str) -> Result<T> {
Err(Error::new(span, message))
}
+/// Decodes the custom attributes for our custom derive.
+#[derive(FromDeriveInput, Default)]
+#[darling(attributes(enumset), default)]
+struct EnumsetAttrs {
+ no_ops: bool,
+ serialize_as_list: bool,
+ serialize_deny_unknown: bool,
+ #[darling(default)]
+ serialize_repr: Option<String>,
+}
+
+/// An variant in the enum set type.
+struct EnumSetValue {
+ name: Ident,
+ variant_repr: u32,
+}
+
+/// Stores information about the enum set type.
+#[allow(dead_code)]
+struct EnumSetInfo {
+ name: Ident,
+ explicit_repr: Option<Ident>,
+ explicit_serde_repr: Option<Ident>,
+ has_signed_repr: bool,
+ has_large_repr: bool,
+ variants: Vec<EnumSetValue>,
+
+ max_discrim: u32,
+ cur_discrim: u32,
+ used_variant_names: HashSet<String>,
+ used_discriminators: HashSet<u32>,
+
+ no_ops: bool,
+ serialize_as_list: bool,
+ serialize_deny_unknown: bool,
+}
+impl EnumSetInfo {
+ fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
+ EnumSetInfo {
+ name: input.ident.clone(),
+ explicit_repr: None,
+ explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
+ has_signed_repr: false,
+ has_large_repr: false,
+ variants: Vec::new(),
+ max_discrim: 0,
+ cur_discrim: 0,
+ used_variant_names: HashSet::new(),
+ used_discriminators: HashSet::new(),
+ no_ops: attrs.no_ops,
+ serialize_as_list: attrs.serialize_as_list,
+ serialize_deny_unknown: attrs.serialize_deny_unknown
+ }
+ }
+
+ fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
+ if self.explicit_repr.is_some() {
+ error(attr_span, "Cannot duplicate #[repr(...)] annotations.")
+ } else {
+ self.explicit_repr = Some(Ident::new(match repr {
+ "Rust" | "C" => return Ok(()), // We assume default repr in these cases.
+ "u8" | "u16" | "u32" => repr,
+ "usize" | "u64" | "u128" => {
+ self.has_large_repr = true;
+ repr
+ }
+ "i8" | "i16" | "i32" => {
+ self.has_signed_repr = true;
+ repr
+ }
+ "isize" | "i64" | "i128" => {
+ self.has_signed_repr = true;
+ self.has_large_repr = true;
+ repr
+ }
+ _ => return error(attr_span, "Unsupported repr.")
+ }, Span::call_site()));
+ Ok(())
+ }
+ }
+ fn push_variant(&mut self, variant: &Variant) -> Result<()> {
+ if self.used_variant_names.contains(&variant.ident.to_string()) {
+ error(variant.span(), "Duplicated variant name.")
+ } else if let Fields::Unit = variant.fields {
+ if let Some((_, expr)) = &variant.discriminant {
+ let discriminant_fail_message = format!(
+ "Enum set discriminants must be `u32`s.{}",
+ if self.has_signed_repr || self.has_large_repr {
+ format!(
+ " ({} discrimiants are still unsupported with reprs that allow them.)",
+ if self.has_large_repr {
+ "larger"
+ } else if self.has_signed_repr {
+ "negative"
+ } else {
+ "larger or negative"
+ }
+ )
+ } else {
+ String::new()
+ },
+ );
+ if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
+ match i.base10_parse() {
+ Ok(val) => self.cur_discrim = val,
+ Err(_) => error(expr.span(), &discriminant_fail_message)?,
+ }
+ } else {
+ error(variant.span(), &discriminant_fail_message)?;
+ }
+ }
+
+ let discriminant = self.cur_discrim;
+ if discriminant >= 128 {
+ let message = if self.variants.len() <= 127 {
+ "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
+ } else {
+ "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
+ };
+ error(variant.span(), message)?;
+ }
+
+ if self.used_discriminators.contains(&discriminant) {
+ error(variant.span(), "Duplicated enum discriminant.")?;
+ }
+
+ self.cur_discrim += 1;
+ if discriminant > self.max_discrim {
+ self.max_discrim = discriminant;
+ }
+ self.variants.push(EnumSetValue {
+ name: variant.ident.clone(),
+ variant_repr: discriminant,
+ });
+ self.used_variant_names.insert(variant.ident.to_string());
+ self.used_discriminators.insert(discriminant);
+
+ Ok(())
+ } else {
+ error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
+ }
+ }
+ fn validate(&self) -> Result<()> {
+ if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
+ let is_overflowed = match explicit_serde_repr.to_string().as_str() {
+ "u8" => self.max_discrim >= 8,
+ "u16" => self.max_discrim >= 16,
+ "u32" => self.max_discrim >= 32,
+ "u64" => self.max_discrim >= 64,
+ "u128" => self.max_discrim >= 128,
+ _ => error(
+ Span::call_site(),
+ "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
+ )?,
+ };
+ if is_overflowed {
+ error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn enum_repr(&self) -> SynTokenStream {
+ if let Some(explicit_repr) = &self.explicit_repr {
+ quote! { #explicit_repr }
+ } else if self.max_discrim < 0x100 {
+ quote! { u8 }
+ } else if self.max_discrim < 0x10000 {
+ quote! { u16 }
+ } else {
+ quote! { u32 }
+ }
+ }
+ fn enumset_repr(&self) -> SynTokenStream {
+ if self.max_discrim <= 7 {
+ quote! { u8 }
+ } else if self.max_discrim <= 15 {
+ quote! { u16 }
+ } else if self.max_discrim <= 31 {
+ quote! { u32 }
+ } else if self.max_discrim <= 63 {
+ quote! { u64 }
+ } else if self.max_discrim <= 127 {
+ quote! { u128 }
+ } else {
+ panic!("max_variant > 127?")
+ }
+ }
+ #[cfg(feature = "serde")]
+ fn serde_repr(&self) -> SynTokenStream {
+ if let Some(serde_repr) = &self.explicit_serde_repr {
+ quote! { #serde_repr }
+ } else {
+ self.enumset_repr()
+ }
+ }
+
+ fn all_variants(&self) -> u128 {
+ let mut accum = 0u128;
+ for variant in &self.variants {
+ assert!(variant.variant_repr <= 127);
+ accum |= 1u128 << variant.variant_repr as u128;
+ }
+ accum
+ }
+}
+
fn enum_set_type_impl(
- name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
- enum_repr: Ident,
+ info: EnumSetInfo,
+// name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
+// enum_repr: Ident,
) -> SynTokenStream {
- let is_uninhabited = variants.is_empty();
- let is_zst = variants.len() == 1;
-
+ let name = &info.name;
let typed_enumset = quote!(::enumset::EnumSet<#name>);
let core = quote!(::enumset::internal::core_export);
- #[cfg(feature = "serde")]
- let serde = quote!(::enumset::internal::serde);
- // proc_macro2 does not support creating u128 literals.
- let all_variants = Literal::u128_unsuffixed(all_variants);
+ let repr = info.enumset_repr();
+ let enum_repr = info.enum_repr();
+ let all_variants = Literal::u128_unsuffixed(info.all_variants());
- let ops = if attrs.no_ops {
+ let ops = if info.no_ops {
quote! {}
} else {
quote! {
}
};
+
+ #[cfg(feature = "serde")]
+ let serde = quote!(::enumset::internal::serde);
+
#[cfg(feature = "serde")]
- let serde_ops = if attrs.serialize_as_list {
+ let serde_ops = if info.serialize_as_list {
let expecting_str = format!("a list of {}", name);
quote! {
fn serialize<S: #serde::Serializer>(
}
}
} else {
- let serialize_repr = attrs.serialize_repr.as_ref()
- .map(|x| Ident::new(&x, Span::call_site()))
- .unwrap_or(repr.clone());
- let check_unknown = if attrs.serialize_deny_unknown {
+ let serialize_repr = info.serde_repr();
+ let check_unknown = if info.serialize_deny_unknown {
quote! {
if value & !#all_variants != 0 {
use #serde::de::Error;
#[cfg(not(feature = "serde"))]
let serde_ops = quote! { };
+ let is_uninhabited = info.variants.is_empty();
+ let is_zst = info.variants.len() == 1;
let into_impl = if is_uninhabited {
quote! {
fn enum_into_u32(self) -> u32 {
}
}
} else if is_zst {
- let variant = &variants[0];
+ let variant = &info.variants[0].name;
quote! {
fn enum_into_u32(self) -> u32 {
self as u32
}
}
-#[derive(FromDeriveInput, Default)]
-#[darling(attributes(enumset), default)]
-struct EnumsetAttrs {
- no_ops: bool,
- serialize_as_list: bool,
- serialize_deny_unknown: bool,
- #[darling(default)]
- serialize_repr: Option<String>,
-}
-
-fn derive_enum_set_type_impl(input: DeriveInput) -> Result<TokenStream> {
- if let Data::Enum(data) = &input.data {
- if !input.generics.params.is_empty() {
- error(
- input.generics.span(),
- "`#[derive(EnumSetType)]` cannot be used on enums with type parameters."
- )
- } else {
- let mut all_variants = 0u128;
- let mut max_variant = 0u32;
- let mut current_variant = 0u32;
- let mut has_manual_discriminant = false;
- let mut variants = Vec::new();
-
- for variant in &data.variants {
- if let Fields::Unit = variant.fields {
- if let Some((_, expr)) = &variant.discriminant {
- if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
- current_variant = match i.base10_parse() {
- Ok(val) => val,
- Err(_) => error(
- expr.span(), "Enum set discriminants must be `u32`s.",
- )?,
- };
- has_manual_discriminant = true;
- } else {
- error(
- variant.span(), "Enum set discriminants must be `u32`s."
- )?;
- }
- }
-
- if current_variant >= 128 {
- let message = if has_manual_discriminant {
- "`#[derive(EnumSetType)]` currently only supports \
- enum discriminants up to 127."
- } else {
- "`#[derive(EnumSetType)]` currently only supports \
- enums up to 128 variants."
- };
- error(variant.span(), message)?;
- }
-
- if all_variants & (1 << current_variant as u128) != 0 {
- error(
- variant.span(),
- &format!("Duplicate enum discriminant: {}", current_variant)
- )?;
- }
- all_variants |= 1 << current_variant as u128;
- if current_variant > max_variant {
- max_variant = current_variant
- }
-
- variants.push(variant.ident.clone());
- current_variant += 1;
- } else {
- error(
- variant.span(),
- "`#[derive(EnumSetType)]` can only be used on fieldless enums."
- )?;
- }
- }
-
- let repr = Ident::new(if max_variant <= 7 {
- "u8"
- } else if max_variant <= 15 {
- "u16"
- } else if max_variant <= 31 {
- "u32"
- } else if max_variant <= 63 {
- "u64"
- } else if max_variant <= 127 {
- "u128"
- } else {
- panic!("max_variant > 127?")
- }, Span::call_site());
-
- let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
- Ok(attrs) => attrs,
- Err(e) => return Ok(e.write_errors().into()),
- };
-
- let mut enum_repr = None;
- for attr in &input.attrs {
- if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
- let meta: Ident = attr.parse_args()?;
- if enum_repr.is_some() {
- error(attr.span(), "Cannot duplicate #[repr(...)] annotations.")?;
- }
- let repr_max_variant = match meta.to_string().as_str() {
- "u8" => 0xFF,
- "u16" => 0xFFFF,
- "u32" => 0xFFFFFFFF,
- _ => error(attr.span(), "Only `u8`, `u16` and `u32` reprs are supported.")?,
- };
- if max_variant > repr_max_variant {
- error(attr.span(), "A variant of this enum overflows its repr.")?;
- }
- enum_repr = Some(meta);
- }
+fn derive_enum_set_type_impl(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
+ if !input.generics.params.is_empty() {
+ error(
+ input.generics.span(),
+ "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
+ )
+ } else if let Data::Enum(data) = &input.data {
+ let mut info = EnumSetInfo::new(&input, attrs);
+ for attr in &input.attrs {
+ if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
+ let meta: Ident = attr.parse_args()?;
+ info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
}
- let enum_repr = enum_repr.unwrap_or_else(|| if max_variant < 0x100 {
- Ident::new("u8", Span::call_site())
- } else if max_variant < 0x10000 {
- Ident::new("u16", Span::call_site())
- } else {
- Ident::new("u32", Span::call_site())
- });
-
- match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
- Some("u8") => if max_variant > 7 {
- error(input.span(), "Too many variants for u8 serialization repr.")?;
- }
- Some("u16") => if max_variant > 15 {
- error(input.span(), "Too many variants for u16 serialization repr.")?;
- }
- Some("u32") => if max_variant > 31 {
- error(input.span(), "Too many variants for u32 serialization repr.")?;
- }
- Some("u64") => if max_variant > 63 {
- error(input.span(), "Too many variants for u64 serialization repr.")?;
- }
- Some("u128") => if max_variant > 127 {
- error(input.span(), "Too many variants for u128 serialization repr.")?;
- }
- None => { }
- Some(x) => error(
- input.span(), &format!("{} is not a valid serialization repr.", x)
- )?,
- };
-
- Ok(enum_set_type_impl(
- &input.ident, all_variants, repr, attrs, variants, enum_repr,
- ).into())
}
+ for variant in &data.variants {
+ info.push_variant(variant)?;
+ }
+ info.validate()?;
+ Ok(enum_set_type_impl(info).into())
} else {
error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
}
#[proc_macro_derive(EnumSetType, attributes(enumset))]
pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
let input: DeriveInput = parse_macro_input!(input);
- match derive_enum_set_type_impl(input) {
+ let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
+ Ok(attrs) => attrs,
+ Err(e) => return e.write_errors().into(),
+ };
+ match derive_enum_set_type_impl(input, attrs) {
Ok(v) => v,
Err(e) => e.to_compile_error().into(),
}