]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
a25e94eab5f701e3634461c8bce1d77bb44e66f2
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit="256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{TokenStream as SynTokenStream, Literal, Span};
8 use std::collections::HashSet;
9 use syn::{*, Result, Error};
10 use syn::spanned::Spanned;
11 use quote::*;
12
13 /// Helper function for emitting compile errors.
14 fn error<T>(span: Span, message: &str) -> Result<T> {
15     Err(Error::new(span, message))
16 }
17
18 /// Decodes the custom attributes for our custom derive.
19 #[derive(FromDeriveInput, Default)]
20 #[darling(attributes(enumset), default)]
21 struct EnumsetAttrs {
22     no_ops: bool,
23     serialize_as_list: bool,
24     serialize_deny_unknown: bool,
25     #[darling(default)]
26     serialize_repr: Option<String>,
27     #[darling(default)]
28     crate_name: Option<String>,
29 }
30
31 /// An variant in the enum set type.
32 struct EnumSetValue {
33     /// The name of the variant.
34     name: Ident,
35     /// The discriminant of the variant.
36     variant_repr: u32,
37 }
38
39 /// Stores information about the enum set type.
40 #[allow(dead_code)]
41 struct EnumSetInfo {
42     /// The name of the enum.
43     name: Ident,
44     /// The crate name to use.
45     crate_name: Option<Ident>,
46     /// The numeric type to serialize the enum as.
47     explicit_serde_repr: Option<Ident>,
48     /// Whether the underlying repr of the enum supports negative values.
49     has_signed_repr: bool,
50     /// Whether the underlying repr of the enum supports values higher than 2^32.
51     has_large_repr: bool,
52     /// A list of variants in the enum.
53     variants: Vec<EnumSetValue>,
54
55     /// The highest encountered variant discriminant.
56     max_discrim: u32,
57     /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
58     cur_discrim: u32,
59     /// A list of variant names that are already in use.
60     used_variant_names: HashSet<String>,
61     /// A list of variant discriminants that are already in use.
62     used_discriminants: HashSet<u32>,
63
64     /// Avoid generating operator overloads on the enum type.
65     no_ops: bool,
66     /// Serialize the enum as a list.
67     serialize_as_list: bool,
68     /// Disallow unknown bits while deserializing the enum.
69     serialize_deny_unknown: bool,
70 }
71 impl EnumSetInfo {
72     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
73         EnumSetInfo {
74             name: input.ident.clone(),
75             crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
76             explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
77             has_signed_repr: false,
78             has_large_repr: false,
79             variants: Vec::new(),
80             max_discrim: 0,
81             cur_discrim: 0,
82             used_variant_names: HashSet::new(),
83             used_discriminants: HashSet::new(),
84             no_ops: attrs.no_ops,
85             serialize_as_list: attrs.serialize_as_list,
86             serialize_deny_unknown: attrs.serialize_deny_unknown
87         }
88     }
89
90     /// Sets an explicit repr for the enumset.
91     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
92         // Check whether the repr is supported, and if so, set some flags for better error
93         // messages later on.
94         match repr {
95             "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
96             "usize" | "u64" | "u128" => {
97                 self.has_large_repr = true;
98                 Ok(())
99             }
100             "i8" | "i16" | "i32" => {
101                 self.has_signed_repr = true;
102                 Ok(())
103             }
104             "isize" | "i64" | "i128" => {
105                 self.has_signed_repr = true;
106                 self.has_large_repr = true;
107                 Ok(())
108             }
109             _ => error(attr_span, "Unsupported repr.")
110         }
111     }
112     /// Adds a variant to the enumset.
113     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
114         if self.used_variant_names.contains(&variant.ident.to_string()) {
115             error(variant.span(), "Duplicated variant name.")
116         } else if let Fields::Unit = variant.fields {
117             // Parse the discriminant.
118             if let Some((_, expr)) = &variant.discriminant {
119                 let discriminant_fail_message = format!(
120                     "Enum set discriminants must be `u32`s.{}",
121                     if self.has_signed_repr || self.has_large_repr {
122                         format!(
123                             " ({} discrimiants are still unsupported with reprs that allow them.)",
124                             if self.has_large_repr {
125                                 "larger"
126                             } else if self.has_signed_repr {
127                                 "negative"
128                             } else {
129                                 "larger or negative"
130                             }
131                         )
132                     } else {
133                         String::new()
134                     },
135                 );
136                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
137                     match i.base10_parse() {
138                         Ok(val) => self.cur_discrim = val,
139                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
140                     }
141                 } else {
142                     error(variant.span(), &discriminant_fail_message)?;
143                 }
144             }
145
146             // Validate the discriminant.
147             let discriminant = self.cur_discrim;
148             if discriminant >= 128 {
149                 let message = if self.variants.len() <= 127 {
150                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
151                 } else {
152                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
153                 };
154                 error(variant.span(), message)?;
155             }
156             if self.used_discriminants.contains(&discriminant) {
157                 error(variant.span(), "Duplicated enum discriminant.")?;
158             }
159
160             // Add the variant to the info.
161             self.cur_discrim += 1;
162             if discriminant > self.max_discrim {
163                 self.max_discrim = discriminant;
164             }
165             self.variants.push(EnumSetValue {
166                 name: variant.ident.clone(),
167                 variant_repr: discriminant,
168             });
169             self.used_variant_names.insert(variant.ident.to_string());
170             self.used_discriminants.insert(discriminant);
171
172             Ok(())
173         } else {
174             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
175         }
176     }
177     /// Validate the enumset type.
178     fn validate(&self) -> Result<()> {
179         // Check if all bits of the bitset can fit in the serialization representation.
180         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
181             let is_overflowed = match explicit_serde_repr.to_string().as_str() {
182                 "u8" => self.max_discrim >= 8,
183                 "u16" => self.max_discrim >= 16,
184                 "u32" => self.max_discrim >= 32,
185                 "u64" => self.max_discrim >= 64,
186                 "u128" => self.max_discrim >= 128,
187                 _ => error(
188                     Span::call_site(),
189                     "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
190                 )?,
191             };
192             if is_overflowed {
193                 error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
194             }
195         }
196         Ok(())
197     }
198
199     /// Computes the underlying type used to store the enumset.
200     fn enumset_repr(&self) -> SynTokenStream {
201         if self.max_discrim <= 7 {
202             quote! { u8 }
203         } else if self.max_discrim <= 15 {
204             quote! { u16 }
205         } else if self.max_discrim <= 31 {
206             quote! { u32 }
207         } else if self.max_discrim <= 63 {
208             quote! { u64 }
209         } else if self.max_discrim <= 127 {
210             quote! { u128 }
211         } else {
212             panic!("max_variant > 127?")
213         }
214     }
215     /// Computes the underlying type used to serialize the enumset.
216     #[cfg(feature = "serde")]
217     fn serde_repr(&self) -> SynTokenStream {
218         if let Some(serde_repr) = &self.explicit_serde_repr {
219             quote! { #serde_repr }
220         } else {
221             self.enumset_repr()
222         }
223     }
224
225     /// Returns a bitmask of all variants in the set.
226     fn all_variants(&self) -> u128 {
227         let mut accum = 0u128;
228         for variant in &self.variants {
229             assert!(variant.variant_repr <= 127);
230             accum |= 1u128 << variant.variant_repr as u128;
231         }
232         accum
233     }
234 }
235
236 /// Generates the actual `EnumSetType` impl.
237 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
238     let name = &info.name;
239     let enumset = match &info.crate_name {
240         Some(crate_name) => quote!(::#crate_name),
241         None => quote!(::enumset),
242     };
243     let typed_enumset = quote!(#enumset::EnumSet<#name>);
244     let core = quote!(#enumset::__internal::core_export);
245
246     let repr = info.enumset_repr();
247     let all_variants = Literal::u128_unsuffixed(info.all_variants());
248
249     let ops = if info.no_ops {
250         quote! {}
251     } else {
252         quote! {
253             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
254                 type Output = #typed_enumset;
255                 fn sub(self, other: O) -> Self::Output {
256                     #enumset::EnumSet::only(self) - other.into()
257                 }
258             }
259             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
260                 type Output = #typed_enumset;
261                 fn bitand(self, other: O) -> Self::Output {
262                     #enumset::EnumSet::only(self) & other.into()
263                 }
264             }
265             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
266                 type Output = #typed_enumset;
267                 fn bitor(self, other: O) -> Self::Output {
268                     #enumset::EnumSet::only(self) | other.into()
269                 }
270             }
271             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
272                 type Output = #typed_enumset;
273                 fn bitxor(self, other: O) -> Self::Output {
274                     #enumset::EnumSet::only(self) ^ other.into()
275                 }
276             }
277             impl #core::ops::Not for #name {
278                 type Output = #typed_enumset;
279                 fn not(self) -> Self::Output {
280                     !#enumset::EnumSet::only(self)
281                 }
282             }
283             impl #core::cmp::PartialEq<#typed_enumset> for #name {
284                 fn eq(&self, other: &#typed_enumset) -> bool {
285                     #enumset::EnumSet::only(*self) == *other
286                 }
287             }
288         }
289     };
290
291
292     #[cfg(feature = "serde")]
293     let serde = quote!(#enumset::__internal::serde);
294
295     #[cfg(feature = "serde")]
296     let serde_ops = if info.serialize_as_list {
297         let expecting_str = format!("a list of {}", name);
298         quote! {
299             fn serialize<S: #serde::Serializer>(
300                 set: #enumset::EnumSet<#name>, ser: S,
301             ) -> #core::result::Result<S::Ok, S::Error> {
302                 use #serde::ser::SerializeSeq;
303                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
304                 for bit in set {
305                     seq.serialize_element(&bit)?;
306                 }
307                 seq.end()
308             }
309             fn deserialize<'de, D: #serde::Deserializer<'de>>(
310                 de: D,
311             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
312                 struct Visitor;
313                 impl <'de> #serde::de::Visitor<'de> for Visitor {
314                     type Value = #enumset::EnumSet<#name>;
315                     fn expecting(
316                         &self, formatter: &mut #core::fmt::Formatter,
317                     ) -> #core::fmt::Result {
318                         write!(formatter, #expecting_str)
319                     }
320                     fn visit_seq<A>(
321                         mut self, mut seq: A,
322                     ) -> #core::result::Result<Self::Value, A::Error> where
323                         A: #serde::de::SeqAccess<'de>
324                     {
325                         let mut accum = #enumset::EnumSet::<#name>::new();
326                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
327                             accum |= val;
328                         }
329                         #core::prelude::v1::Ok(accum)
330                     }
331                 }
332                 de.deserialize_seq(Visitor)
333             }
334         }
335     } else {
336         let serialize_repr = info.serde_repr();
337         let check_unknown = if info.serialize_deny_unknown {
338             quote! {
339                 if value & !#all_variants != 0 {
340                     use #serde::de::Error;
341                     return #core::prelude::v1::Err(
342                         D::Error::custom("enumset contains unknown bits")
343                     )
344                 }
345             }
346         } else {
347             quote! { }
348         };
349         quote! {
350             fn serialize<S: #serde::Serializer>(
351                 set: #enumset::EnumSet<#name>, ser: S,
352             ) -> #core::result::Result<S::Ok, S::Error> {
353                 #serde::Serialize::serialize(&(set.__priv_repr as #serialize_repr), ser)
354             }
355             fn deserialize<'de, D: #serde::Deserializer<'de>>(
356                 de: D,
357             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
358                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
359                 #check_unknown
360                 #core::prelude::v1::Ok(#enumset::EnumSet {
361                     __priv_repr: (value & #all_variants) as #repr,
362                 })
363             }
364         }
365     };
366
367     #[cfg(not(feature = "serde"))]
368     let serde_ops = quote! { };
369
370     let is_uninhabited = info.variants.is_empty();
371     let is_zst = info.variants.len() == 1;
372     let into_impl = if is_uninhabited {
373         quote! {
374             fn enum_into_u32(self) -> u32 {
375                 panic!(concat!(stringify!(#name), " is uninhabited."))
376             }
377             unsafe fn enum_from_u32(val: u32) -> Self {
378                 panic!(concat!(stringify!(#name), " is uninhabited."))
379             }
380         }
381     } else if is_zst {
382         let variant = &info.variants[0].name;
383         quote! {
384             fn enum_into_u32(self) -> u32 {
385                 self as u32
386             }
387             unsafe fn enum_from_u32(val: u32) -> Self {
388                 #name::#variant
389             }
390         }
391     } else {
392         let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
393         let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
394
395         let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
396             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
397         let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
398             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
399
400         quote! {
401             fn enum_into_u32(self) -> u32 {
402                 self as u32
403             }
404             unsafe fn enum_from_u32(val: u32) -> Self {
405                 // We put these in const fields so the branches they guard aren't generated even
406                 // on -O0
407                 #(const #const_field: bool =
408                     #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
409                 match val {
410                     // Every valid variant value has an explicit branch. If they get optimized out,
411                     // great. If the representation has changed somehow, and they don't, oh well,
412                     // there's still no UB.
413                     #(#variant_value => #name::#variant_name,)*
414                     // Helps hint to the LLVM that this is a transmute. Note that this branch is
415                     // still unreachable.
416                     #(x if #const_field => {
417                         let x = x as #int_type;
418                         *(&x as *const _ as *const #name)
419                     })*
420                     // Default case. Sometimes causes LLVM to generate a table instead of a simple
421                     // transmute, but, oh well.
422                     _ => #core::hint::unreachable_unchecked(),
423                 }
424             }
425         }
426     };
427
428     let eq_impl = if is_uninhabited {
429         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
430     } else {
431         quote!((*self as u32) == (*other as u32))
432     };
433
434     // used in the enum_set! macro `const fn`s.
435     let self_as_repr_mask = if is_uninhabited {
436         quote! { 0 } // impossible anyway
437     } else {
438         quote! { 1 << self as #repr }
439     };
440
441     quote! {
442         unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
443             type Repr = #repr;
444             const ALL_BITS: Self::Repr = #all_variants;
445             #into_impl
446             #serde_ops
447         }
448
449         unsafe impl #enumset::EnumSetType for #name { }
450
451         impl #core::cmp::PartialEq for #name {
452             fn eq(&self, other: &Self) -> bool {
453                 #eq_impl
454             }
455         }
456         impl #core::cmp::Eq for #name { }
457         impl #core::clone::Clone for #name {
458             fn clone(&self) -> Self {
459                 *self
460             }
461         }
462         impl #core::marker::Copy for #name { }
463
464         impl #name {
465             /// Creates a new enumset with only this variant.
466             #[deprecated(note = "This method is an internal implementation detail generated by \
467                                  the `enumset` crate's procedural macro. It should not be used \
468                                  directly. Use `EnumSet::only` instead.")]
469             #[doc(hidden)]
470             pub const fn __impl_enumset_internal__const_only(self) -> #enumset::EnumSet<#name> {
471                 #enumset::EnumSet { __priv_repr: #self_as_repr_mask }
472             }
473
474             /// Creates a new enumset with this variant added.
475             #[deprecated(note = "This method is an internal implementation detail generated by \
476                                  the `enumset` crate's procedural macro. It should not be used \
477                                  directly. Use the `|` operator instead.")]
478             #[doc(hidden)]
479             pub const fn __impl_enumset_internal__const_merge(
480                 self, chain: #enumset::EnumSet<#name>,
481             ) -> #enumset::EnumSet<#name> {
482                 #enumset::EnumSet { __priv_repr: chain.__priv_repr | #self_as_repr_mask }
483             }
484         }
485
486         #ops
487     }
488 }
489
490 /// A wrapper that parses the input enum.
491 #[proc_macro_derive(EnumSetType, attributes(enumset))]
492 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
493     let input: DeriveInput = parse_macro_input!(input);
494     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
495         Ok(attrs) => attrs,
496         Err(e) => return e.write_errors().into(),
497     };
498     match derive_enum_set_type_0(input, attrs) {
499         Ok(v) => v,
500         Err(e) => e.to_compile_error().into(),
501     }
502 }
503 fn derive_enum_set_type_0(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
504     if !input.generics.params.is_empty() {
505         error(
506             input.generics.span(),
507             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
508         )
509     } else if let Data::Enum(data) = &input.data {
510         let mut info = EnumSetInfo::new(&input, attrs);
511         for attr in &input.attrs {
512             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
513                 let meta: Ident = attr.parse_args()?;
514                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
515             }
516         }
517         for variant in &data.variants {
518             info.push_variant(variant)?;
519         }
520         info.validate()?;
521         Ok(enum_set_type_impl(info).into())
522     } else {
523         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
524     }
525 }