1 #![recursion_limit="256"]
3 extern crate proc_macro;
6 use proc_macro::TokenStream;
7 use proc_macro2::{TokenStream as SynTokenStream, Literal};
8 use std::collections::HashSet;
9 use syn::{*, Result, Error};
10 use syn::export::Span;
11 use syn::spanned::Spanned;
14 /// Helper function for emitting compile errors.
///
/// Always returns `Err`, wrapping `message` in a `syn::Error` attached to `span`, so
/// callers can either return it directly or propagate it with `?`.
15 fn error<T>(span: Span, message: &str) -> Result<T> {
16 Err(Error::new(span, message))
// NOTE(review): this function's closing `}` (original line 17) is missing from this listing.
19 /// Decodes the custom attributes for our custom derive.
///
/// Filled in by darling from `#[enumset(...)]` attributes on the derive input (the
/// `attributes(enumset)` option below). `Default` is derived and the container-level
/// `default` option is set, presumably so that omitted keys fall back to their defaults.
// NOTE(review): original lines 22-23, 26 and 28 are missing from this listing. They presumably
// hold the `struct EnumsetAttrs {` header, at least one more field, a field attribute,
// and the closing `}`.
20 #[derive(FromDeriveInput, Default)]
21 #[darling(attributes(enumset), default)]
// Selects the list-based serde format instead of the integer format.
24 serialize_as_list: bool,
// In the integer serde format, reject deserialized bits that match no variant.
25 serialize_deny_unknown: bool,
// Name of the integer type to serialize as; converted to an `Ident` in `EnumSetInfo::new`.
27 serialize_repr: Option<String>,
30 /// A variant in the enum set type.
36 /// Stores information about the enum set type.
///
/// It is created by `EnumSetInfo::new` and populated by `push_explicit_repr` and `push_variant`.
/// It is then consumed by `enum_set_type_impl` to generate code.
// NOTE(review): this listing omits original lines 37-39, 42, 44-46, 49-50 and 53. They
// presumably hold the struct header, `name`, `has_large_repr`, `max_discrim`, `cur_discrim`
// and `no_ops` (all used elsewhere in this file), plus the closing `}`.
// Integer type requested via `serialize_repr`, if any; used as the serde integer type.
40 explicit_serde_repr: Option<Ident>,
// Set when the enum's `#[repr]` is a signed integer type.
41 has_signed_repr: bool,
// Accepted variants, in declaration order, with their discriminants.
43 variants: Vec<EnumSetValue>,
// Variant names seen so far, used to reject duplicates.
47 used_variant_names: HashSet<String>,
// Discriminants seen so far, used to reject duplicates.
48 used_discriminators: HashSet<u32>,
// Copied from `EnumsetAttrs::serialize_as_list`.
51 serialize_as_list: bool,
// Copied from `EnumsetAttrs::serialize_deny_unknown`.
52 serialize_deny_unknown: bool,
// Creates the initial `EnumSetInfo` for `input`, copying the serde-related settings from
// `attrs`. The repr flags start out `false` and the duplicate-tracking sets start empty.
// Reprs and variants are added afterwards via `push_explicit_repr` and `push_variant`.
// NOTE(review): the struct-literal header, several field initializers (original lines
// 56, 61-63, 66) and the closing braces are missing from this listing.
55 fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
57 name: input.ident.clone(),
// `serialize_repr` arrives as a plain string and becomes an identifier with a call-site span.
58 explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
59 has_signed_repr: false,
60 has_large_repr: false,
64 used_variant_names: HashSet::new(),
65 used_discriminators: HashSet::new(),
67 serialize_as_list: attrs.serialize_as_list,
68 serialize_deny_unknown: attrs.serialize_deny_unknown
// Records what an explicit `#[repr(...)]` on the enum implies:
// - `usize`, `u64` and `u128` (and `isize`, `i64`, `i128`) set `has_large_repr`.
// - Signed integer types set `has_signed_repr`.
// - `Rust`, `C`, `u8`, `u16` and `u32` change nothing.
// - Any other repr is rejected with an error at `attr_span`.
// In the visible code these flags are only read by `push_variant` to word its error message.
72 fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
// NOTE(review): the `match repr {` line and the returns / closing braces of the arms below
// (original lines 73, 77-78, 81-82, 86-87, 89-90) are missing from this listing;
// presumably each block arm ends with `Ok(())`.
74 "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
75 "usize" | "u64" | "u128" => {
76 self.has_large_repr = true;
79 "i8" | "i16" | "i32" => {
80 self.has_signed_repr = true;
83 "isize" | "i64" | "i128" => {
84 self.has_signed_repr = true;
85 self.has_large_repr = true;
88 _ => error(attr_span, "Unsupported repr.")
// Validates one enum variant and records it. It rejects:
// - duplicate names;
// - non-unit variants;
// - explicit discriminants that are not integer literals parseable as `u32`;
// - discriminants >= 128;
// - duplicate discriminants.
// A variant's discriminant is `cur_discrim`, which an explicit `= N` overwrites first.
// The counter then advances by one, which matches how Rust numbers implicit discriminants.
// `max_discrim` tracks the highest accepted discriminant.
// NOTE(review): many original lines (99, 102, 104-112, 117-118, 120-122, 127, 129,
// 131-132, 135-136, 140, 144, 147-149, 151-152) are missing from this listing.
91 fn push_variant(&mut self, variant: &Variant) -> Result<()> {
92 if self.used_variant_names.contains(&variant.ident.to_string()) {
93 error(variant.span(), "Duplicated variant name.")
94 } else if let Fields::Unit = variant.fields {
// An explicit `= expr` discriminant resets the running counter.
95 if let Some((_, expr)) = &variant.discriminant {
// Error text, extended with an extra hint when the enum's repr could allow
// signed or large discriminant values.
96 let discriminant_fail_message = format!(
97 "Enum set discriminants must be `u32`s.{}",
98 if self.has_signed_repr || self.has_large_repr {
// NOTE(review): "discrimiants" in the string below is a typo for "discriminants";
// fixing it changes the emitted compile-error text.
100 " ({} discrimiants are still unsupported with reprs that allow them.)",
101 if self.has_large_repr {
103 } else if self.has_signed_repr {
// Only a plain integer literal is accepted. Other expressions fall through to the error below,
// including negative values, which parse as a unary `-` expression.
113 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
// Parses into `u32` (the type of `cur_discrim`); out-of-range literals are errors.
114 match i.base10_parse() {
115 Ok(val) => self.cur_discrim = val,
116 Err(_) => error(expr.span(), &discriminant_fail_message)?,
119 error(variant.span(), &discriminant_fail_message)?;
123 let discriminant = self.cur_discrim;
// The widest bitset is 128 bits, so bit index 127 is the maximum.
124 if discriminant >= 128 {
// With more than 127 variants already recorded, report the variant-count limit instead.
125 let message = if self.variants.len() <= 127 {
126 "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
128 "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
130 error(variant.span(), message)?;
133 if self.used_discriminators.contains(&discriminant) {
134 error(variant.span(), "Duplicated enum discriminant.")?;
// The next implicit discriminant follows this one. This cannot overflow:
// `discriminant < 128` here, because the check above returns early via `?`.
137 self.cur_discrim += 1;
138 if discriminant > self.max_discrim {
139 self.max_discrim = discriminant;
141 self.variants.push(EnumSetValue {
142 name: variant.ident.clone(),
143 variant_repr: discriminant,
145 self.used_variant_names.insert(variant.ident.to_string());
146 self.used_discriminators.insert(discriminant);
// Tuple and struct variants carry data and cannot be represented as a single bit.
150 error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
// Checks that an explicit `serialize_repr` is wide enough to hold bit `max_discrim`.
// Bit index N needs an integer with more than N bits.
// NOTE(review): original lines 161-162, 164-166 and 168-172 are missing from this listing.
// They presumably hold the fallback match arm, the `if is_overflowed` check, `Ok(())`
// and the closing braces.
153 fn validate(&self) -> Result<()> {
154 if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
155 let is_overflowed = match explicit_serde_repr.to_string().as_str() {
156 "u8" => self.max_discrim >= 8,
157 "u16" => self.max_discrim >= 16,
158 "u32" => self.max_discrim >= 32,
159 "u64" => self.max_discrim >= 64,
160 "u128" => self.max_discrim >= 128,
// Presumably the error message of the `_` arm for unsupported type names.
163 "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
167 error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
// Picks the narrowest bitset width that has a bit for `max_discrim`: bit indices up to
// 7, 15, 31, 63 or 127.
// NOTE(review): the `quote!` type tokens for each branch (original lines 175, 177, 179,
// 181, 183) and the closing braces are missing from this listing; presumably they are
// `u8` through `u128`.
173 fn enumset_repr(&self) -> SynTokenStream {
174 if self.max_discrim <= 7 {
176 } else if self.max_discrim <= 15 {
178 } else if self.max_discrim <= 31 {
180 } else if self.max_discrim <= 63 {
182 } else if self.max_discrim <= 127 {
// Not expected to be reached: `push_variant` rejects discriminants >= 128, and
// `max_discrim` is only updated for accepted variants.
185 panic!("max_variant > 127?")
/// Integer type used by the integer serde format: the explicit `serialize_repr`, if
/// one was given.
// NOTE(review): the fallback branch (original lines 192-195) is missing from this
// listing; presumably it reuses `enumset_repr()` — confirm.
188 #[cfg(feature = "serde")]
189 fn serde_repr(&self) -> SynTokenStream {
190 if let Some(serde_repr) = &self.explicit_serde_repr {
191 quote! { #serde_repr }
// Builds a mask with bit `variant_repr` set for every recorded variant.
// NOTE(review): the `accum` return and closing braces (original lines 202-205) are
// missing from this listing.
197 fn all_variants(&self) -> u128 {
198 let mut accum = 0u128;
199 for variant in &self.variants {
// Upheld by `push_variant`, which rejects discriminants >= 128.
200 assert!(variant.variant_repr <= 127);
201 accum |= 1u128 << variant.variant_repr as u128;
// Generates the trait impls for the enum described by `info`. In the order of the code
// below, it emits:
// - operator overloads (unless `no_ops`);
// - serde helpers (only with the `serde` feature);
// - the enum <-> `u32` conversions;
// - `EnumSetTypePrivate` / `EnumSetType` plus `PartialEq`, `Eq`, `Clone` and `Copy` impls.
// NOTE(review): a large share of this function's original lines are missing from this
// listing, including the `quote!` openers, the `else` branches and all closing braces.
// The structure below is only partially visible.
207 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
208 let name = &info.name;
209 let typed_enumset = quote!(::enumset::EnumSet<#name>);
// Generated code reaches `core` items through enumset's `__internal::core_export` path.
210 let core = quote!(::enumset::__internal::core_export);
// Backing integer type of the bitset, plus the mask of all valid bits as an
// unsuffixed literal.
212 let repr = info.enumset_repr();
213 let all_variants = Literal::u128_unsuffixed(info.all_variants());
// `no_ops` presumably suppresses the operator impls below; the branch bodies
// (original lines 216-218) are missing from this listing.
215 let ops = if info.no_ops {
// `variant OP x`: each binary operator promotes the variant to a one-element set
// and delegates to the same operator on `EnumSet`, for any `x` convertible into one.
219 impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
220 type Output = #typed_enumset;
221 fn sub(self, other: O) -> Self::Output {
222 ::enumset::EnumSet::only(self) - other.into()
225 impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
226 type Output = #typed_enumset;
227 fn bitand(self, other: O) -> Self::Output {
228 ::enumset::EnumSet::only(self) & other.into()
231 impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
232 type Output = #typed_enumset;
233 fn bitor(self, other: O) -> Self::Output {
234 ::enumset::EnumSet::only(self) | other.into()
237 impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
238 type Output = #typed_enumset;
239 fn bitxor(self, other: O) -> Self::Output {
240 ::enumset::EnumSet::only(self) ^ other.into()
// `!variant` is the complement of the one-element set.
243 impl #core::ops::Not for #name {
244 type Output = #typed_enumset;
245 fn not(self) -> Self::Output {
246 !::enumset::EnumSet::only(self)
// Lets a variant be compared directly with an `EnumSet`: it is equal only to the
// one-element set containing it.
249 impl #core::cmp::PartialEq<#typed_enumset> for #name {
250 fn eq(&self, other: &#typed_enumset) -> bool {
251 ::enumset::EnumSet::only(*self) == *other
258 #[cfg(feature = "serde")]
259 let serde = quote!(::enumset::__internal::serde);
// List format (`serialize_as_list`): the set is serialized as a sequence of variants and
// rebuilt from such a sequence when deserializing.
261 #[cfg(feature = "serde")]
262 let serde_ops = if info.serialize_as_list {
// Text the generated visitor writes to its formatter when describing what it expects.
263 let expecting_str = format!("a list of {}", name);
265 fn serialize<S: #serde::Serializer>(
266 set: ::enumset::EnumSet<#name>, ser: S,
267 ) -> #core::result::Result<S::Ok, S::Error> {
268 use #serde::ser::SerializeSeq;
269 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
// `bit` is presumably each variant yielded by iterating `set`; the loop header
// (original line 270) is missing from this listing.
271 seq.serialize_element(&bit)?;
275 fn deserialize<'de, D: #serde::Deserializer<'de>>(
277 ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
279 impl <'de> #serde::de::Visitor<'de> for Visitor {
280 type Value = ::enumset::EnumSet<#name>;
282 &self, formatter: &mut #core::fmt::Formatter,
283 ) -> #core::fmt::Result {
284 write!(formatter, #expecting_str)
287 mut self, mut seq: A,
288 ) -> #core::result::Result<Self::Value, A::Error> where
289 A: #serde::de::SeqAccess<'de>
// Rebuild the set one deserialized variant at a time.
291 let mut accum = ::enumset::EnumSet::<#name>::new();
292 while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
295 #core::prelude::v1::Ok(accum)
298 de.deserialize_seq(Visitor)
// Integer format: the set's underlying bits are serialized as `serialize_repr`.
302 let serialize_repr = info.serde_repr();
// With `serialize_deny_unknown`, deserialization fails if any bit outside the known
// variants is set.
303 let check_unknown = if info.serialize_deny_unknown {
305 if value & !#all_variants != 0 {
306 use #serde::de::Error;
307 return #core::prelude::v1::Err(
308 D::Error::custom("enumset contains unknown bits")
316 fn serialize<S: #serde::Serializer>(
317 set: ::enumset::EnumSet<#name>, ser: S,
318 ) -> #core::result::Result<S::Ok, S::Error> {
319 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
321 fn deserialize<'de, D: #serde::Deserializer<'de>>(
323 ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
324 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
// Bits outside `all_variants` are masked off before the set is built.
326 #core::prelude::v1::Ok(::enumset::EnumSet {
327 __enumset_underlying: (value & #all_variants) as #repr,
// Without the `serde` feature no serde helpers are generated.
333 #[cfg(not(feature = "serde"))]
334 let serde_ops = quote! { };
// Conversions between the enum and its `u32` discriminant. They are specialised for the
// zero-variant case, the single-variant case and the general case.
336 let is_uninhabited = info.variants.is_empty();
337 let is_zst = info.variants.len() == 1;
338 let into_impl = if is_uninhabited {
// An enum with no variants has no values, so these bodies can never run legitimately.
340 fn enum_into_u32(self) -> u32 {
341 panic!(concat!(stringify!(#name), " is uninhabited."))
343 unsafe fn enum_from_u32(val: u32) -> Self {
344 panic!(concat!(stringify!(#name), " is uninhabited."))
// Single-variant case (presumably guarded by `is_zst` on a missing line): the only
// variant is the only possible value.
348 let variant = &info.variants[0].name;
350 fn enum_into_u32(self) -> u32 {
353 unsafe fn enum_from_u32(val: u32) -> Self {
// General case: `enum_from_u32` matches `val` against every variant's discriminant.
358 let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
359 let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
// Parallel lists: `IS_U8` pairs with `u8`, and so on up to `IS_U128` with `u128`.
361 let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
362 .iter().map(|x| Ident::new(x, Span::call_site())).collect();
363 let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
364 .iter().map(|x| Ident::new(x, Span::call_site())).collect();
367 fn enum_into_u32(self) -> u32 {
370 unsafe fn enum_from_u32(val: u32) -> Self {
371 // We put these in const fields so they aren't generated even on -O0
372 #(const #const_field: bool =
373 #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
375 // Every valid variant value has an explicit branch. If they get optimized out,
376 // great. If the representation has changed somehow, and they don't, oh well,
377 // there's still no UB.
378 #(#variant_value => #name::#variant_name,)*
379 // Helps hint to the LLVM that this is a transmute. Note that this branch is
380 // still unreachable.
// Every valid discriminant is matched by the arms above, so this arm can only
// be reached for invalid input. The `unsafe fn` contract presumably forbids
// such input — confirm. The guard selects the integer type whose size equals
// the enum's, and the arm reinterprets `x`'s bytes as the enum.
381 #(x if #const_field => {
382 let x = x as #int_type;
383 *(&x as *const _ as *const #name)
385 // Default case. Sometimes causes LLVM to generate a table instead of a simple
386 // transmute, but, oh well.
387 _ => #core::hint::unreachable_unchecked(),
// Body of the generated `PartialEq::eq`: discriminants are compared as `u32`.
393 let eq_impl = if is_uninhabited {
394 quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
396 quote!((*self as u32) == (*other as u32))
// Final output: the private trait carrying the bitset metadata, the public marker
// trait, and the comparison/clone/copy impls.
400 unsafe impl ::enumset::__internal::EnumSetTypePrivate for #name {
// Mask of all bits that correspond to a variant.
402 const ALL_BITS: Self::Repr = #all_variants;
407 unsafe impl ::enumset::EnumSetType for #name { }
409 impl #core::cmp::PartialEq for #name {
410 fn eq(&self, other: &Self) -> bool {
414 impl #core::cmp::Eq for #name { }
415 impl #core::clone::Clone for #name {
416 fn clone(&self) -> Self {
420 impl #core::marker::Copy for #name { }
// Validates the derive input and builds the generated impls. It rejects generic enums and
// non-enums, then records every `#[repr(...)]` attribute and every variant. It stops at the
// first error and otherwise generates the impls.
// NOTE(review): original lines 428, 431, 438-439, 442-443 and 445 are missing from this
// listing. Line 442 or 443 presumably calls `info.validate()?` — confirm, since `validate`
// is not called anywhere else in the visible code.
426 fn derive_enum_set_type_impl(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
// Any generic parameter is rejected, even though the message only mentions
// type parameters.
427 if !input.generics.params.is_empty() {
429 input.generics.span(),
430 "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
432 } else if let Data::Enum(data) = &input.data {
433 let mut info = EnumSetInfo::new(&input, attrs);
434 for attr in &input.attrs {
435 if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
// The `repr` arguments must parse as a single identifier (e.g. `u8`); anything
// else is returned as a parse error.
436 let meta: Ident = attr.parse_args()?;
437 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
// Reprs are recorded before any variant, which lets `push_variant` word its
// discriminant errors based on the repr flags.
440 for variant in &data.variants {
441 info.push_variant(variant)?;
444 Ok(enum_set_type_impl(info).into())
446 error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
/// Entry point for `#[derive(EnumSetType)]`.
///
/// Parses the input and decodes `#[enumset(...)]` options with darling. Any failure is turned
/// into compile errors rather than a panic.
// NOTE(review): original lines 454, 456, 458 and 460-461 (the `Ok` arms and closing braces)
// are missing from this listing.
450 #[proc_macro_derive(EnumSetType, attributes(enumset))]
451 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
452 let input: DeriveInput = parse_macro_input!(input);
453 let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
// Invalid `#[enumset(...)]` options are reported as darling's compile errors.
455 Err(e) => return e.write_errors().into(),
457 match derive_enum_set_type_impl(input, attrs) {
459 Err(e) => e.to_compile_error().into(),