]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
Add a new std feature for enumset. Fixes #34
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit="256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{TokenStream as SynTokenStream, Literal, Span};
8 use std::collections::HashSet;
9 use syn::{*, Result, Error};
10 use syn::spanned::Spanned;
11 use quote::*;
12
13 /// Helper function for emitting compile errors.
14 fn error<T>(span: Span, message: &str) -> Result<T> {
15     Err(Error::new(span, message))
16 }
17
18 /// Decodes the custom attributes for our custom derive.
19 #[derive(FromDeriveInput, Default)]
20 #[darling(attributes(enumset), default)]
21 struct EnumsetAttrs {
22     no_ops: bool,
23     no_super_impls: bool,
24     serialize_as_list: bool,
25     serialize_deny_unknown: bool,
26     #[darling(default)]
27     serialize_repr: Option<String>,
28     #[darling(default)]
29     crate_name: Option<String>,
30 }
31
32 /// An variant in the enum set type.
33 struct EnumSetValue {
34     /// The name of the variant.
35     name: Ident,
36     /// The discriminant of the variant.
37     variant_repr: u32,
38 }
39
40 /// Stores information about the enum set type.
41 #[allow(dead_code)]
42 struct EnumSetInfo {
43     /// The name of the enum.
44     name: Ident,
45     /// The crate name to use.
46     crate_name: Option<Ident>,
47     /// The numeric type to serialize the enum as.
48     explicit_serde_repr: Option<Ident>,
49     /// Whether the underlying repr of the enum supports negative values.
50     has_signed_repr: bool,
51     /// Whether the underlying repr of the enum supports values higher than 2^32.
52     has_large_repr: bool,
53     /// A list of variants in the enum.
54     variants: Vec<EnumSetValue>,
55
56     /// The highest encountered variant discriminant.
57     max_discrim: u32,
58     /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
59     cur_discrim: u32,
60     /// A list of variant names that are already in use.
61     used_variant_names: HashSet<String>,
62     /// A list of variant discriminants that are already in use.
63     used_discriminants: HashSet<u32>,
64
65     /// Avoid generating operator overloads on the enum type.
66     no_ops: bool,
67     /// Avoid generating implementations for `Clone`, `Copy`, `Eq`, and `PartialEq`.
68     no_super_impls: bool,
69     /// Serialize the enum as a list.
70     serialize_as_list: bool,
71     /// Disallow unknown bits while deserializing the enum.
72     serialize_deny_unknown: bool,
73 }
74 impl EnumSetInfo {
75     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
76         EnumSetInfo {
77             name: input.ident.clone(),
78             crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
79             explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
80             has_signed_repr: false,
81             has_large_repr: false,
82             variants: Vec::new(),
83             max_discrim: 0,
84             cur_discrim: 0,
85             used_variant_names: HashSet::new(),
86             used_discriminants: HashSet::new(),
87             no_ops: attrs.no_ops,
88             no_super_impls: attrs.no_super_impls,
89             serialize_as_list: attrs.serialize_as_list,
90             serialize_deny_unknown: attrs.serialize_deny_unknown
91         }
92     }
93
94     /// Sets an explicit repr for the enumset.
95     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
96         // Check whether the repr is supported, and if so, set some flags for better error
97         // messages later on.
98         match repr {
99             "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
100             "usize" | "u64" | "u128" => {
101                 self.has_large_repr = true;
102                 Ok(())
103             }
104             "i8" | "i16" | "i32" => {
105                 self.has_signed_repr = true;
106                 Ok(())
107             }
108             "isize" | "i64" | "i128" => {
109                 self.has_signed_repr = true;
110                 self.has_large_repr = true;
111                 Ok(())
112             }
113             _ => error(attr_span, "Unsupported repr.")
114         }
115     }
116     /// Adds a variant to the enumset.
117     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
118         if self.used_variant_names.contains(&variant.ident.to_string()) {
119             error(variant.span(), "Duplicated variant name.")
120         } else if let Fields::Unit = variant.fields {
121             // Parse the discriminant.
122             if let Some((_, expr)) = &variant.discriminant {
123                 let discriminant_fail_message = format!(
124                     "Enum set discriminants must be `u32`s.{}",
125                     if self.has_signed_repr || self.has_large_repr {
126                         format!(
127                             " ({} discrimiants are still unsupported with reprs that allow them.)",
128                             if self.has_large_repr {
129                                 "larger"
130                             } else if self.has_signed_repr {
131                                 "negative"
132                             } else {
133                                 "larger or negative"
134                             }
135                         )
136                     } else {
137                         String::new()
138                     },
139                 );
140                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
141                     match i.base10_parse() {
142                         Ok(val) => self.cur_discrim = val,
143                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
144                     }
145                 } else {
146                     error(variant.span(), &discriminant_fail_message)?;
147                 }
148             }
149
150             // Validate the discriminant.
151             let discriminant = self.cur_discrim;
152             if discriminant >= 128 {
153                 let message = if self.variants.len() <= 127 {
154                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
155                 } else {
156                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
157                 };
158                 error(variant.span(), message)?;
159             }
160             if self.used_discriminants.contains(&discriminant) {
161                 error(variant.span(), "Duplicated enum discriminant.")?;
162             }
163
164             // Add the variant to the info.
165             self.cur_discrim += 1;
166             if discriminant > self.max_discrim {
167                 self.max_discrim = discriminant;
168             }
169             self.variants.push(EnumSetValue {
170                 name: variant.ident.clone(),
171                 variant_repr: discriminant,
172             });
173             self.used_variant_names.insert(variant.ident.to_string());
174             self.used_discriminants.insert(discriminant);
175
176             Ok(())
177         } else {
178             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
179         }
180     }
181     /// Validate the enumset type.
182     fn validate(&self) -> Result<()> {
183         // Check if all bits of the bitset can fit in the serialization representation.
184         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
185             let is_overflowed = match explicit_serde_repr.to_string().as_str() {
186                 "u8" => self.max_discrim >= 8,
187                 "u16" => self.max_discrim >= 16,
188                 "u32" => self.max_discrim >= 32,
189                 "u64" => self.max_discrim >= 64,
190                 "u128" => self.max_discrim >= 128,
191                 _ => error(
192                     Span::call_site(),
193                     "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
194                 )?,
195             };
196             if is_overflowed {
197                 error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
198             }
199         }
200         Ok(())
201     }
202
203     /// Computes the underlying type used to store the enumset.
204     fn enumset_repr(&self) -> SynTokenStream {
205         if self.max_discrim <= 7 {
206             quote! { u8 }
207         } else if self.max_discrim <= 15 {
208             quote! { u16 }
209         } else if self.max_discrim <= 31 {
210             quote! { u32 }
211         } else if self.max_discrim <= 63 {
212             quote! { u64 }
213         } else if self.max_discrim <= 127 {
214             quote! { u128 }
215         } else {
216             panic!("max_variant > 127?")
217         }
218     }
219     /// Computes the underlying type used to serialize the enumset.
220     #[cfg(feature = "serde")]
221     fn serde_repr(&self) -> SynTokenStream {
222         if let Some(serde_repr) = &self.explicit_serde_repr {
223             quote! { #serde_repr }
224         } else {
225             self.enumset_repr()
226         }
227     }
228
229     /// Returns a bitmask of all variants in the set.
230     fn all_variants(&self) -> u128 {
231         let mut accum = 0u128;
232         for variant in &self.variants {
233             assert!(variant.variant_repr <= 127);
234             accum |= 1u128 << variant.variant_repr as u128;
235         }
236         accum
237     }
238 }
239
240 /// Generates the actual `EnumSetType` impl.
241 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
242     let name = &info.name;
243
244     let enumset = match &info.crate_name {
245         Some(crate_name) => quote!(::#crate_name),
246         None => {
247             #[cfg(feature = "proc-macro-crate")]
248             {
249                 use proc_macro_crate::FoundCrate;
250
251                 let crate_name = proc_macro_crate::crate_name("enumset");
252                 match crate_name {
253                     Ok(FoundCrate::Name(name)) => {
254                         let ident = Ident::new(&name, Span::call_site());
255                         quote!(::#ident)
256                     }
257                     _ => quote!(::enumset),
258                 }
259             }
260
261             #[cfg(not(feature = "proc-macro-crate"))]
262             {
263                 quote!(::enumset)
264             }
265         },
266     };
267     let typed_enumset = quote!(#enumset::EnumSet<#name>);
268     let core = quote!(#enumset::__internal::core_export);
269
270     let repr = info.enumset_repr();
271     let all_variants = Literal::u128_unsuffixed(info.all_variants());
272
273     let ops = if info.no_ops {
274         quote! {}
275     } else {
276         quote! {
277             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
278                 type Output = #typed_enumset;
279                 fn sub(self, other: O) -> Self::Output {
280                     #enumset::EnumSet::only(self) - other.into()
281                 }
282             }
283             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
284                 type Output = #typed_enumset;
285                 fn bitand(self, other: O) -> Self::Output {
286                     #enumset::EnumSet::only(self) & other.into()
287                 }
288             }
289             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
290                 type Output = #typed_enumset;
291                 fn bitor(self, other: O) -> Self::Output {
292                     #enumset::EnumSet::only(self) | other.into()
293                 }
294             }
295             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
296                 type Output = #typed_enumset;
297                 fn bitxor(self, other: O) -> Self::Output {
298                     #enumset::EnumSet::only(self) ^ other.into()
299                 }
300             }
301             impl #core::ops::Not for #name {
302                 type Output = #typed_enumset;
303                 fn not(self) -> Self::Output {
304                     !#enumset::EnumSet::only(self)
305                 }
306             }
307             impl #core::cmp::PartialEq<#typed_enumset> for #name {
308                 fn eq(&self, other: &#typed_enumset) -> bool {
309                     #enumset::EnumSet::only(*self) == *other
310                 }
311             }
312         }
313     };
314
315
316     #[cfg(feature = "serde")]
317     let serde = quote!(#enumset::__internal::serde);
318
319     #[cfg(feature = "serde")]
320     let serde_ops = if info.serialize_as_list {
321         let expecting_str = format!("a list of {}", name);
322         quote! {
323             fn serialize<S: #serde::Serializer>(
324                 set: #enumset::EnumSet<#name>, ser: S,
325             ) -> #core::result::Result<S::Ok, S::Error> {
326                 use #serde::ser::SerializeSeq;
327                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
328                 for bit in set {
329                     seq.serialize_element(&bit)?;
330                 }
331                 seq.end()
332             }
333             fn deserialize<'de, D: #serde::Deserializer<'de>>(
334                 de: D,
335             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
336                 struct Visitor;
337                 impl <'de> #serde::de::Visitor<'de> for Visitor {
338                     type Value = #enumset::EnumSet<#name>;
339                     fn expecting(
340                         &self, formatter: &mut #core::fmt::Formatter,
341                     ) -> #core::fmt::Result {
342                         write!(formatter, #expecting_str)
343                     }
344                     fn visit_seq<A>(
345                         mut self, mut seq: A,
346                     ) -> #core::result::Result<Self::Value, A::Error> where
347                         A: #serde::de::SeqAccess<'de>
348                     {
349                         let mut accum = #enumset::EnumSet::<#name>::new();
350                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
351                             accum |= val;
352                         }
353                         #core::prelude::v1::Ok(accum)
354                     }
355                 }
356                 de.deserialize_seq(Visitor)
357             }
358         }
359     } else {
360         let serialize_repr = info.serde_repr();
361         let check_unknown = if info.serialize_deny_unknown {
362             quote! {
363                 if value & !#all_variants != 0 {
364                     use #serde::de::Error;
365                     return #core::prelude::v1::Err(
366                         D::Error::custom("enumset contains unknown bits")
367                     )
368                 }
369             }
370         } else {
371             quote! { }
372         };
373         quote! {
374             fn serialize<S: #serde::Serializer>(
375                 set: #enumset::EnumSet<#name>, ser: S,
376             ) -> #core::result::Result<S::Ok, S::Error> {
377                 #serde::Serialize::serialize(&(set.__priv_repr as #serialize_repr), ser)
378             }
379             fn deserialize<'de, D: #serde::Deserializer<'de>>(
380                 de: D,
381             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
382                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
383                 #check_unknown
384                 #core::prelude::v1::Ok(#enumset::EnumSet {
385                     __priv_repr: (value & #all_variants) as #repr,
386                 })
387             }
388         }
389     };
390
391     #[cfg(not(feature = "serde"))]
392     let serde_ops = quote! { };
393
394     let is_uninhabited = info.variants.is_empty();
395     let is_zst = info.variants.len() == 1;
396     let into_impl = if is_uninhabited {
397         quote! {
398             fn enum_into_u32(self) -> u32 {
399                 panic!(concat!(stringify!(#name), " is uninhabited."))
400             }
401             unsafe fn enum_from_u32(val: u32) -> Self {
402                 panic!(concat!(stringify!(#name), " is uninhabited."))
403             }
404         }
405     } else if is_zst {
406         let variant = &info.variants[0].name;
407         quote! {
408             fn enum_into_u32(self) -> u32 {
409                 self as u32
410             }
411             unsafe fn enum_from_u32(val: u32) -> Self {
412                 #name::#variant
413             }
414         }
415     } else {
416         let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
417         let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
418
419         let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
420             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
421         let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
422             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
423
424         quote! {
425             fn enum_into_u32(self) -> u32 {
426                 self as u32
427             }
428             unsafe fn enum_from_u32(val: u32) -> Self {
429                 // We put these in const fields so the branches they guard aren't generated even
430                 // on -O0
431                 #(const #const_field: bool =
432                     #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
433                 match val {
434                     // Every valid variant value has an explicit branch. If they get optimized out,
435                     // great. If the representation has changed somehow, and they don't, oh well,
436                     // there's still no UB.
437                     #(#variant_value => #name::#variant_name,)*
438                     // Helps hint to the LLVM that this is a transmute. Note that this branch is
439                     // still unreachable.
440                     #(x if #const_field => {
441                         let x = x as #int_type;
442                         *(&x as *const _ as *const #name)
443                     })*
444                     // Default case. Sometimes causes LLVM to generate a table instead of a simple
445                     // transmute, but, oh well.
446                     _ => #core::hint::unreachable_unchecked(),
447                 }
448             }
449         }
450     };
451
452     let eq_impl = if is_uninhabited {
453         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
454     } else {
455         quote!((*self as u32) == (*other as u32))
456     };
457
458     // used in the enum_set! macro `const fn`s.
459     let self_as_repr_mask = if is_uninhabited {
460         quote! { 0 } // impossible anyway
461     } else {
462         quote! { 1 << self as #repr }
463     };
464
465     let super_impls = if info.no_super_impls {
466         quote! {}
467     } else {
468         quote! {
469             impl #core::cmp::PartialEq for #name {
470                 fn eq(&self, other: &Self) -> bool {
471                     #eq_impl
472                 }
473             }
474             impl #core::cmp::Eq for #name { }
475             #[allow(clippy::expl_impl_clone_on_copy)]
476             impl #core::clone::Clone for #name {
477                 fn clone(&self) -> Self {
478                     *self
479                 }
480             }
481             impl #core::marker::Copy for #name { }
482         }
483     };
484
485     quote! {
486         unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
487             type Repr = #repr;
488             const ALL_BITS: Self::Repr = #all_variants;
489             #into_impl
490             #serde_ops
491         }
492
493         unsafe impl #enumset::EnumSetType for #name { }
494
495         #super_impls
496
497         impl #name {
498             /// Creates a new enumset with only this variant.
499             #[deprecated(note = "This method is an internal implementation detail generated by \
500                                  the `enumset` crate's procedural macro. It should not be used \
501                                  directly. Use `EnumSet::only` instead.")]
502             #[doc(hidden)]
503             pub const fn __impl_enumset_internal__const_only(self) -> #enumset::EnumSet<#name> {
504                 #enumset::EnumSet { __priv_repr: #self_as_repr_mask }
505             }
506
507             /// Creates a new enumset with this variant added.
508             #[deprecated(note = "This method is an internal implementation detail generated by \
509                                  the `enumset` crate's procedural macro. It should not be used \
510                                  directly. Use the `|` operator instead.")]
511             #[doc(hidden)]
512             pub const fn __impl_enumset_internal__const_merge(
513                 self, chain: #enumset::EnumSet<#name>,
514             ) -> #enumset::EnumSet<#name> {
515                 #enumset::EnumSet { __priv_repr: chain.__priv_repr | #self_as_repr_mask }
516             }
517         }
518
519         #ops
520     }
521 }
522
523 #[proc_macro_derive(EnumSetType, attributes(enumset))]
524 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
525     let input: DeriveInput = parse_macro_input!(input);
526     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
527         Ok(attrs) => attrs,
528         Err(e) => return e.write_errors().into(),
529     };
530     match derive_enum_set_type_0(input, attrs) {
531         Ok(v) => v,
532         Err(e) => e.to_compile_error().into(),
533     }
534 }
535 fn derive_enum_set_type_0(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
536     if !input.generics.params.is_empty() {
537         error(
538             input.generics.span(),
539             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
540         )
541     } else if let Data::Enum(data) = &input.data {
542         let mut info = EnumSetInfo::new(&input, attrs);
543         for attr in &input.attrs {
544             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
545                 let meta: Ident = attr.parse_args()?;
546                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
547             }
548         }
549         for variant in &data.variants {
550             info.push_variant(variant)?;
551         }
552         info.validate()?;
553         Ok(enum_set_type_impl(info).into())
554     } else {
555         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
556     }
557 }