1 #![recursion_limit="256"]
2 #![cfg_attr(feature = "nightly", feature(proc_macro_diagnostic))]
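4 // TODO: Read #[repr(...)] attributes. (NOTE(review): `derive_enum_set_type_impl` already parses `#[repr(...)]`; confirm whether this TODO is still needed.)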
6 extern crate proc_macro;
9 use proc_macro::TokenStream;
10 use proc_macro2::{TokenStream as SynTokenStream, Literal};
11 use syn::{*, Result, Error};
12 use syn::export::Span;
13 use syn::spanned::Spanned;
// Builds an `Err` holding a `syn::Error` created from `span` and `message`.
// Generic over `T` so it can be used both with `?` to bail out early and as
// the tail expression of any function returning `syn::Result<T>`.
16 fn error<T>(span: Span, message: &str) -> Result<T> {
17 Err(Error::new(span, message))
// Builds the pieces of generated code for an enum deriving `EnumSetType`:
// operator impls on the bare enum, optional serde helpers, the `u32`
// conversion bodies, the equality body, and the trait impls.
//
// * `name` - identifier of the enum; interpolated into every generated impl.
// * `all_variants` - bitmask of the discriminants in use; emitted as
//   `ALL_BITS` and used to mask/check deserialized values.
// * `repr` - integer type the deserialized bits are cast to, and the fallback
//   serialization type when `attrs.serialize_repr` is unset.
// * `attrs` - derive options read below (`no_ops`, `serialize_as_list`,
//   `serialize_deny_unknown`, `serialize_repr`).
// * `variants` - variant identifiers, in the order the caller pushed them.
//
// NOTE(review): the remainder of the signature (including the `enum_repr`
// binding interpolated further down) and the final assembly of the fragments
// are not visible here - confirm against the full file.
20 fn enum_set_type_impl(
21 name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
// Zero-variant and single-variant enums are detected up front; both flags
// select dedicated conversion/equality code further down.
24 let is_uninhabited = variants.is_empty();
25 let is_zst = variants.len() == 1;
// Absolute paths used inside the generated code so it does not depend on
// what the user's module has imported.
27 let typed_enumset = quote!(::enumset::EnumSet<#name>);
28 let core = quote!(::enumset::internal::core_export);
29 #[cfg(feature = "serde")]
30 let serde = quote!(::enumset::internal::serde);
32 // proc_macro2 does not support creating u128 literals.
// NOTE(review): the comment above looks at odds with the call to
// `Literal::u128_unsuffixed` on the next line - confirm whether it is stale.
// From here on `all_variants` is shadowed by an unsuffixed literal token.
33 let all_variants = Literal::u128_unsuffixed(all_variants);
// Operator impls for the bare enum type, gated on the `no_ops` option.
// Each binary operator wraps `self` in a one-element set via `EnumSet::only`
// and applies the matching set operator to `other.into()`.
35 let ops = if attrs.no_ops {
39 impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
40 type Output = #typed_enumset;
41 fn sub(self, other: O) -> Self::Output {
42 ::enumset::EnumSet::only(self) - other.into()
45 impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
46 type Output = #typed_enumset;
47 fn bitand(self, other: O) -> Self::Output {
48 ::enumset::EnumSet::only(self) & other.into()
51 impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
52 type Output = #typed_enumset;
53 fn bitor(self, other: O) -> Self::Output {
54 ::enumset::EnumSet::only(self) | other.into()
57 impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
58 type Output = #typed_enumset;
59 fn bitxor(self, other: O) -> Self::Output {
60 ::enumset::EnumSet::only(self) ^ other.into()
// `!variant` yields the complement of the one-element set.
63 impl #core::ops::Not for #name {
64 type Output = #typed_enumset;
65 fn not(self) -> Self::Output {
66 !::enumset::EnumSet::only(self)
// Allows comparing a bare variant against a set: equal only when the set is
// exactly `EnumSet::only(variant)`.
69 impl #core::cmp::PartialEq<#typed_enumset> for #name {
70 fn eq(&self, other: &#typed_enumset) -> bool {
71 ::enumset::EnumSet::only(*self) == *other
// Serde support (only with the `serde` feature). With `serialize_as_list`
// the set is written as a sequence of elements; otherwise as an integer.
77 #[cfg(feature = "serde")]
78 let serde_ops = if attrs.serialize_as_list {
// Message interpolated into the generated visitor's `write!` call below.
79 let expecting_str = format!("a list of {}", name);
// List form, serialize: opens a sequence sized by `set.len()` and writes
// elements into it.
81 fn serialize<S: #serde::Serializer>(
82 set: ::enumset::EnumSet<#name>, ser: S,
83 ) -> #core::result::Result<S::Ok, S::Error> {
84 use #serde::ser::SerializeSeq;
85 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
87 seq.serialize_element(&bit)?;
// List form, deserialize: a visitor reads elements of type `#name` from the
// sequence, starting from an empty set, and returns the accumulated set.
91 fn deserialize<'de, D: #serde::Deserializer<'de>>(
93 ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
95 impl <'de> #serde::de::Visitor<'de> for Visitor {
96 type Value = ::enumset::EnumSet<#name>;
98 &self, formatter: &mut #core::fmt::Formatter,
99 ) -> #core::fmt::Result {
100 write!(formatter, #expecting_str)
103 mut self, mut seq: A,
104 ) -> #core::result::Result<Self::Value, A::Error> where
105 A: #serde::de::SeqAccess<'de>
107 let mut accum = ::enumset::EnumSet::<#name>::new();
// NOTE(review): the statement folding `val` into `accum` is not visible in
// this listing - confirm against the full file.
108 while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
111 #core::prelude::v1::Ok(accum)
114 de.deserialize_seq(Visitor)
// Integer form: serialize as the type named by the `serialize_repr` option
// when present, otherwise as `repr`.
118 let serialize_repr = attrs.serialize_repr.as_ref()
119 .map(|x| Ident::new(&x, Span::call_site()))
120 .unwrap_or(repr.clone());
// With `serialize_deny_unknown`, the generated deserializer rejects any value
// that has bits set outside `all_variants`.
121 let check_unknown = if attrs.serialize_deny_unknown {
123 if value & !#all_variants != 0 {
124 use #serde::de::Error;
125 return #core::prelude::v1::Err(
126 D::Error::custom("enumset contains unknown bits")
// Integer form, serialize: cast the set's underlying bits to
// `serialize_repr` and delegate to that type's `Serialize` impl.
134 fn serialize<S: #serde::Serializer>(
135 set: ::enumset::EnumSet<#name>, ser: S,
136 ) -> #core::result::Result<S::Ok, S::Error> {
137 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
// Integer form, deserialize: read a `serialize_repr`, keep only the bits in
// `all_variants`, and cast the result to the set's `repr`.
139 fn deserialize<'de, D: #serde::Deserializer<'de>>(
141 ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
142 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
144 #core::prelude::v1::Ok(::enumset::EnumSet {
145 __enumset_underlying: (value & #all_variants) as #repr,
// Without the `serde` feature no serde helpers are generated.
151 #[cfg(not(feature = "serde"))]
152 let serde_ops = quote! { };
// Bodies of `enum_into_u32` / `enum_from_u32`, in three shapes:
// uninhabited enums panic, and the remaining shapes follow.
154 let into_impl = if is_uninhabited {
156 fn enum_into_u32(self) -> u32 {
157 panic!(concat!(stringify!(#name), " is uninhabited."))
159 unsafe fn enum_from_u32(val: u32) -> Self {
160 panic!(concat!(stringify!(#name), " is uninhabited."))
// Shape using the first (index 0) variant.
// NOTE(review): presumably the `is_zst` (single-variant) case - the branch
// condition and the function bodies are not visible here.
164 let variant = &variants[0];
166 fn enum_into_u32(self) -> u32 {
169 unsafe fn enum_from_u32(val: u32) -> Self {
// Remaining shape: `enum_from_u32` narrows `val` to `enum_repr` and
// transmutes it into the enum.
175 fn enum_into_u32(self) -> u32 {
178 unsafe fn enum_from_u32(val: u32) -> Self {
179 #core::mem::transmute(val as #enum_repr)
// Equality body: panics for uninhabited enums, otherwise compares the two
// values after casting each to `u32`.
184 let eq_impl = if is_uninhabited {
185 quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
187 quote!((*self as u32) == (*other as u32))
// Trait impls emitted for the enum: the private and public enumset traits
// (with `ALL_BITS` set to the variant bitmask), then `PartialEq`, `Eq`,
// `Clone` and `Copy`.
191 unsafe impl ::enumset::internal::EnumSetTypePrivate for #name {
193 const ALL_BITS: Self::Repr = #all_variants;
198 unsafe impl ::enumset::EnumSetType for #name { }
200 impl #core::cmp::PartialEq for #name {
201 fn eq(&self, other: &Self) -> bool {
205 impl #core::cmp::Eq for #name { }
206 impl #core::clone::Clone for #name {
207 fn clone(&self) -> Self {
211 impl #core::marker::Copy for #name { }
// Options accepted in `#[enumset(...)]` on the enum, parsed by darling via
// the derived `FromDeriveInput`. The container-level `default` means fields
// that are not specified fall back to `Default`.
217 #[derive(FromDeriveInput, Default)]
218 #[darling(attributes(enumset), default)]
219 struct EnumsetAttrs {
// Serialize the set as a sequence of elements instead of as an integer.
221 serialize_as_list: bool,
// When deserializing the integer form, fail on bits outside the known
// variants rather than masking them off.
222 serialize_deny_unknown: bool,
// Name of the integer type used for the integer serialized form; the derive
// accepts "u8", "u16", "u32", "u64" and "u128" and falls back to the set's
// own repr when unset.
224 serialize_repr: Option<String>,
// Validates the derive input and produces the generated code.
//
// Steps: require an enum without generic parameters; walk the variants
// (unit variants only, integer-literal discriminants only) to build a bitmask
// of used discriminants; pick the set's storage type from the highest
// discriminant; parse `#[enumset(...)]` options; validate any `#[repr(...)]`
// and the `serialize_repr` option; then call `enum_set_type_impl`.
//
// Returns `Err` with a spanned `syn::Error` for any validation failure.
227 fn derive_enum_set_type_impl(input: DeriveInput) -> Result<TokenStream> {
228 if let Data::Enum(data) = &input.data {
229 if !input.generics.params.is_empty() {
231 input.generics.span(),
232 "`#[derive(EnumSetType)]` cannot be used on enums with type parameters."
// Bit `n` of `all_variants` is set when some variant has discriminant `n`.
235 let mut all_variants = 0u128;
// Highest discriminant seen so far.
236 let mut max_variant = 0u32;
// Discriminant the next variant receives if it does not specify one.
237 let mut current_variant = 0u32;
// Only used to choose between the two "too large" error messages below.
238 let mut has_manual_discriminant = false;
239 let mut variants = Vec::new();
241 for variant in &data.variants {
242 if let Fields::Unit = variant.fields {
// An explicit discriminant must be an integer literal that parses; it resets
// the running counter.
243 if let Some((_, expr)) = &variant.discriminant {
244 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
245 current_variant = match i.base10_parse() {
248 expr.span(), "Enum set discriminants must be `u32`s.",
251 has_manual_discriminant = true;
254 variant.span(), "Enum set discriminants must be `u32`s."
// The bitmask is a `u128`, so discriminants must stay below 128. This check
// also guards the shifts below against overflow.
259 if current_variant >= 128 {
260 let message = if has_manual_discriminant {
261 "`#[derive(EnumSetType)]` currently only supports \
262 enum discriminants up to 127."
264 "`#[derive(EnumSetType)]` currently only supports \
265 enums up to 128 variants."
267 error(variant.span(), message)?;
// A bit that is already set means two variants share a discriminant.
270 if all_variants & (1 << current_variant as u128) != 0 {
273 &format!("Duplicate enum discriminant: {}", current_variant)
276 all_variants |= 1 << current_variant as u128;
277 if current_variant > max_variant {
278 max_variant = current_variant
281 variants.push(variant.ident.clone());
// Implicit discriminants continue from the previous variant's value.
282 current_variant += 1;
286 "`#[derive(EnumSetType)]` can only be used on fieldless enums."
// Storage type for the set, chosen by the highest bit index needed. The
// thresholds are the top bit index of 8-, 16-, 32-, 64- and 128-bit integers.
// NOTE(review): the type-name literals are not visible in this listing.
291 let repr = Ident::new(if max_variant <= 7 {
293 } else if max_variant <= 15 {
295 } else if max_variant <= 31 {
297 } else if max_variant <= 63 {
299 } else if max_variant <= 127 {
// Discriminants >= 128 were already rejected in the loop above, so this
// branch is not expected to run.
302 panic!("max_variant > 127?")
303 }, Span::call_site());
// darling failures are turned into error tokens and returned as the (Ok)
// output rather than as a `syn::Error`.
305 let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
307 Err(e) => return Ok(e.write_errors().into()),
// Look for `#[repr(...)]` on the enum: at most one is allowed, its argument
// must be a single identifier, and every discriminant must fit the type.
310 let mut enum_repr = None;
311 for attr in &input.attrs {
312 if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
313 let meta: Ident = attr.parse_args()?;
314 if enum_repr.is_some() {
315 error(attr.span(), "Cannot duplicate #[repr(...)] annotations.")?;
317 let repr_max_variant = match meta.to_string().as_str() {
321 _ => error(attr.span(), "Only `u8`, `u16` and `u32` reprs are supported.")?,
323 if max_variant > repr_max_variant {
324 error(attr.span(), "A variant of this enum overflows its repr.")?;
326 enum_repr = Some(meta);
// Without an explicit `#[repr]`, pick an integer type wide enough for the
// highest discriminant.
// NOTE(review): since `max_variant` cannot exceed 127 here, only the `u8`
// branch appears reachable - confirm.
329 let enum_repr = enum_repr.unwrap_or_else(|| if max_variant < 0x100 {
330 Ident::new("u8", Span::call_site())
331 } else if max_variant < 0x10000 {
332 Ident::new("u16", Span::call_site())
334 Ident::new("u32", Span::call_site())
// The requested serialization type must have enough bits for the highest
// discriminant; anything other than the listed names is rejected.
337 match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
338 Some("u8") => if max_variant > 7 {
339 error(input.span(), "Too many variants for u8 serialization repr.")?;
341 Some("u16") => if max_variant > 15 {
342 error(input.span(), "Too many variants for u16 serialization repr.")?;
344 Some("u32") => if max_variant > 31 {
345 error(input.span(), "Too many variants for u32 serialization repr.")?;
347 Some("u64") => if max_variant > 63 {
348 error(input.span(), "Too many variants for u64 serialization repr.")?;
350 Some("u128") => if max_variant > 127 {
351 error(input.span(), "Too many variants for u128 serialization repr.")?;
355 input.span(), &format!("{} is not a valid serialization repr.", x)
// All checks passed: hand the collected data to the code generator.
359 Ok(enum_set_type_impl(
360 &input.ident, all_variants, repr, attrs, variants, enum_repr,
// Reached when the input is not an enum (struct or union).
364 error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
// Entry point for `#[derive(EnumSetType)]`; also registers `enumset` as a
// helper attribute so `#[enumset(...)]` is accepted on the enum.
// Parses the input as a `DeriveInput` and delegates to
// `derive_enum_set_type_impl`; a returned `syn::Error` is converted into
// compile-error tokens instead of panicking.
368 #[proc_macro_derive(EnumSetType, attributes(enumset))]
369 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
370 let input: DeriveInput = parse_macro_input!(input);
371 match derive_enum_set_type_impl(input) {
373 Err(e) => e.to_compile_error().into(),