]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
Use proc-macro-crate to allow easier use when enumset is renamed.
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit="256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{TokenStream as SynTokenStream, Literal, Span};
8 use std::collections::HashSet;
9 use proc_macro_crate::FoundCrate;
10 use syn::{*, Result, Error};
11 use syn::spanned::Spanned;
12 use quote::*;
13
14 /// Helper function for emitting compile errors.
15 fn error<T>(span: Span, message: &str) -> Result<T> {
16     Err(Error::new(span, message))
17 }
18
19 /// Decodes the custom attributes for our custom derive.
20 #[derive(FromDeriveInput, Default)]
21 #[darling(attributes(enumset), default)]
22 struct EnumsetAttrs {
23     no_ops: bool,
24     no_super_impls: bool,
25     serialize_as_list: bool,
26     serialize_deny_unknown: bool,
27     #[darling(default)]
28     serialize_repr: Option<String>,
29     #[darling(default)]
30     crate_name: Option<String>,
31 }
32
33 /// An variant in the enum set type.
34 struct EnumSetValue {
35     /// The name of the variant.
36     name: Ident,
37     /// The discriminant of the variant.
38     variant_repr: u32,
39 }
40
41 /// Stores information about the enum set type.
42 #[allow(dead_code)]
43 struct EnumSetInfo {
44     /// The name of the enum.
45     name: Ident,
46     /// The crate name to use.
47     crate_name: Option<Ident>,
48     /// The numeric type to serialize the enum as.
49     explicit_serde_repr: Option<Ident>,
50     /// Whether the underlying repr of the enum supports negative values.
51     has_signed_repr: bool,
52     /// Whether the underlying repr of the enum supports values higher than 2^32.
53     has_large_repr: bool,
54     /// A list of variants in the enum.
55     variants: Vec<EnumSetValue>,
56
57     /// The highest encountered variant discriminant.
58     max_discrim: u32,
59     /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
60     cur_discrim: u32,
61     /// A list of variant names that are already in use.
62     used_variant_names: HashSet<String>,
63     /// A list of variant discriminants that are already in use.
64     used_discriminants: HashSet<u32>,
65
66     /// Avoid generating operator overloads on the enum type.
67     no_ops: bool,
68     /// Avoid generating implementations for `Clone`, `Copy`, `Eq`, and `PartialEq`.
69     no_super_impls: bool,
70     /// Serialize the enum as a list.
71     serialize_as_list: bool,
72     /// Disallow unknown bits while deserializing the enum.
73     serialize_deny_unknown: bool,
74 }
75 impl EnumSetInfo {
76     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
77         EnumSetInfo {
78             name: input.ident.clone(),
79             crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
80             explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
81             has_signed_repr: false,
82             has_large_repr: false,
83             variants: Vec::new(),
84             max_discrim: 0,
85             cur_discrim: 0,
86             used_variant_names: HashSet::new(),
87             used_discriminants: HashSet::new(),
88             no_ops: attrs.no_ops,
89             no_super_impls: attrs.no_super_impls,
90             serialize_as_list: attrs.serialize_as_list,
91             serialize_deny_unknown: attrs.serialize_deny_unknown
92         }
93     }
94
95     /// Sets an explicit repr for the enumset.
96     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
97         // Check whether the repr is supported, and if so, set some flags for better error
98         // messages later on.
99         match repr {
100             "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
101             "usize" | "u64" | "u128" => {
102                 self.has_large_repr = true;
103                 Ok(())
104             }
105             "i8" | "i16" | "i32" => {
106                 self.has_signed_repr = true;
107                 Ok(())
108             }
109             "isize" | "i64" | "i128" => {
110                 self.has_signed_repr = true;
111                 self.has_large_repr = true;
112                 Ok(())
113             }
114             _ => error(attr_span, "Unsupported repr.")
115         }
116     }
117     /// Adds a variant to the enumset.
118     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
119         if self.used_variant_names.contains(&variant.ident.to_string()) {
120             error(variant.span(), "Duplicated variant name.")
121         } else if let Fields::Unit = variant.fields {
122             // Parse the discriminant.
123             if let Some((_, expr)) = &variant.discriminant {
124                 let discriminant_fail_message = format!(
125                     "Enum set discriminants must be `u32`s.{}",
126                     if self.has_signed_repr || self.has_large_repr {
127                         format!(
128                             " ({} discrimiants are still unsupported with reprs that allow them.)",
129                             if self.has_large_repr {
130                                 "larger"
131                             } else if self.has_signed_repr {
132                                 "negative"
133                             } else {
134                                 "larger or negative"
135                             }
136                         )
137                     } else {
138                         String::new()
139                     },
140                 );
141                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
142                     match i.base10_parse() {
143                         Ok(val) => self.cur_discrim = val,
144                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
145                     }
146                 } else {
147                     error(variant.span(), &discriminant_fail_message)?;
148                 }
149             }
150
151             // Validate the discriminant.
152             let discriminant = self.cur_discrim;
153             if discriminant >= 128 {
154                 let message = if self.variants.len() <= 127 {
155                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
156                 } else {
157                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
158                 };
159                 error(variant.span(), message)?;
160             }
161             if self.used_discriminants.contains(&discriminant) {
162                 error(variant.span(), "Duplicated enum discriminant.")?;
163             }
164
165             // Add the variant to the info.
166             self.cur_discrim += 1;
167             if discriminant > self.max_discrim {
168                 self.max_discrim = discriminant;
169             }
170             self.variants.push(EnumSetValue {
171                 name: variant.ident.clone(),
172                 variant_repr: discriminant,
173             });
174             self.used_variant_names.insert(variant.ident.to_string());
175             self.used_discriminants.insert(discriminant);
176
177             Ok(())
178         } else {
179             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
180         }
181     }
182     /// Validate the enumset type.
183     fn validate(&self) -> Result<()> {
184         // Check if all bits of the bitset can fit in the serialization representation.
185         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
186             let is_overflowed = match explicit_serde_repr.to_string().as_str() {
187                 "u8" => self.max_discrim >= 8,
188                 "u16" => self.max_discrim >= 16,
189                 "u32" => self.max_discrim >= 32,
190                 "u64" => self.max_discrim >= 64,
191                 "u128" => self.max_discrim >= 128,
192                 _ => error(
193                     Span::call_site(),
194                     "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
195                 )?,
196             };
197             if is_overflowed {
198                 error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
199             }
200         }
201         Ok(())
202     }
203
204     /// Computes the underlying type used to store the enumset.
205     fn enumset_repr(&self) -> SynTokenStream {
206         if self.max_discrim <= 7 {
207             quote! { u8 }
208         } else if self.max_discrim <= 15 {
209             quote! { u16 }
210         } else if self.max_discrim <= 31 {
211             quote! { u32 }
212         } else if self.max_discrim <= 63 {
213             quote! { u64 }
214         } else if self.max_discrim <= 127 {
215             quote! { u128 }
216         } else {
217             panic!("max_variant > 127?")
218         }
219     }
220     /// Computes the underlying type used to serialize the enumset.
221     #[cfg(feature = "serde")]
222     fn serde_repr(&self) -> SynTokenStream {
223         if let Some(serde_repr) = &self.explicit_serde_repr {
224             quote! { #serde_repr }
225         } else {
226             self.enumset_repr()
227         }
228     }
229
230     /// Returns a bitmask of all variants in the set.
231     fn all_variants(&self) -> u128 {
232         let mut accum = 0u128;
233         for variant in &self.variants {
234             assert!(variant.variant_repr <= 127);
235             accum |= 1u128 << variant.variant_repr as u128;
236         }
237         accum
238     }
239 }
240
241 /// Generates the actual `EnumSetType` impl.
242 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
243     let name = &info.name;
244     let enumset = match &info.crate_name {
245         Some(crate_name) => quote!(::#crate_name),
246         None => {
247             let crate_name = proc_macro_crate::crate_name("enumset");
248             match crate_name {
249                 Ok(FoundCrate::Name(name)) => {
250                     let ident = Ident::new(&name, Span::call_site());
251                     quote!(::#ident)
252                 }
253                 _ => quote!(::enumset),
254             }
255         },
256     };
257     let typed_enumset = quote!(#enumset::EnumSet<#name>);
258     let core = quote!(#enumset::__internal::core_export);
259
260     let repr = info.enumset_repr();
261     let all_variants = Literal::u128_unsuffixed(info.all_variants());
262
263     let ops = if info.no_ops {
264         quote! {}
265     } else {
266         quote! {
267             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
268                 type Output = #typed_enumset;
269                 fn sub(self, other: O) -> Self::Output {
270                     #enumset::EnumSet::only(self) - other.into()
271                 }
272             }
273             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
274                 type Output = #typed_enumset;
275                 fn bitand(self, other: O) -> Self::Output {
276                     #enumset::EnumSet::only(self) & other.into()
277                 }
278             }
279             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
280                 type Output = #typed_enumset;
281                 fn bitor(self, other: O) -> Self::Output {
282                     #enumset::EnumSet::only(self) | other.into()
283                 }
284             }
285             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
286                 type Output = #typed_enumset;
287                 fn bitxor(self, other: O) -> Self::Output {
288                     #enumset::EnumSet::only(self) ^ other.into()
289                 }
290             }
291             impl #core::ops::Not for #name {
292                 type Output = #typed_enumset;
293                 fn not(self) -> Self::Output {
294                     !#enumset::EnumSet::only(self)
295                 }
296             }
297             impl #core::cmp::PartialEq<#typed_enumset> for #name {
298                 fn eq(&self, other: &#typed_enumset) -> bool {
299                     #enumset::EnumSet::only(*self) == *other
300                 }
301             }
302         }
303     };
304
305
306     #[cfg(feature = "serde")]
307     let serde = quote!(#enumset::__internal::serde);
308
309     #[cfg(feature = "serde")]
310     let serde_ops = if info.serialize_as_list {
311         let expecting_str = format!("a list of {}", name);
312         quote! {
313             fn serialize<S: #serde::Serializer>(
314                 set: #enumset::EnumSet<#name>, ser: S,
315             ) -> #core::result::Result<S::Ok, S::Error> {
316                 use #serde::ser::SerializeSeq;
317                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
318                 for bit in set {
319                     seq.serialize_element(&bit)?;
320                 }
321                 seq.end()
322             }
323             fn deserialize<'de, D: #serde::Deserializer<'de>>(
324                 de: D,
325             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
326                 struct Visitor;
327                 impl <'de> #serde::de::Visitor<'de> for Visitor {
328                     type Value = #enumset::EnumSet<#name>;
329                     fn expecting(
330                         &self, formatter: &mut #core::fmt::Formatter,
331                     ) -> #core::fmt::Result {
332                         write!(formatter, #expecting_str)
333                     }
334                     fn visit_seq<A>(
335                         mut self, mut seq: A,
336                     ) -> #core::result::Result<Self::Value, A::Error> where
337                         A: #serde::de::SeqAccess<'de>
338                     {
339                         let mut accum = #enumset::EnumSet::<#name>::new();
340                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
341                             accum |= val;
342                         }
343                         #core::prelude::v1::Ok(accum)
344                     }
345                 }
346                 de.deserialize_seq(Visitor)
347             }
348         }
349     } else {
350         let serialize_repr = info.serde_repr();
351         let check_unknown = if info.serialize_deny_unknown {
352             quote! {
353                 if value & !#all_variants != 0 {
354                     use #serde::de::Error;
355                     return #core::prelude::v1::Err(
356                         D::Error::custom("enumset contains unknown bits")
357                     )
358                 }
359             }
360         } else {
361             quote! { }
362         };
363         quote! {
364             fn serialize<S: #serde::Serializer>(
365                 set: #enumset::EnumSet<#name>, ser: S,
366             ) -> #core::result::Result<S::Ok, S::Error> {
367                 #serde::Serialize::serialize(&(set.__priv_repr as #serialize_repr), ser)
368             }
369             fn deserialize<'de, D: #serde::Deserializer<'de>>(
370                 de: D,
371             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
372                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
373                 #check_unknown
374                 #core::prelude::v1::Ok(#enumset::EnumSet {
375                     __priv_repr: (value & #all_variants) as #repr,
376                 })
377             }
378         }
379     };
380
381     #[cfg(not(feature = "serde"))]
382     let serde_ops = quote! { };
383
384     let is_uninhabited = info.variants.is_empty();
385     let is_zst = info.variants.len() == 1;
386     let into_impl = if is_uninhabited {
387         quote! {
388             fn enum_into_u32(self) -> u32 {
389                 panic!(concat!(stringify!(#name), " is uninhabited."))
390             }
391             unsafe fn enum_from_u32(val: u32) -> Self {
392                 panic!(concat!(stringify!(#name), " is uninhabited."))
393             }
394         }
395     } else if is_zst {
396         let variant = &info.variants[0].name;
397         quote! {
398             fn enum_into_u32(self) -> u32 {
399                 self as u32
400             }
401             unsafe fn enum_from_u32(val: u32) -> Self {
402                 #name::#variant
403             }
404         }
405     } else {
406         let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
407         let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
408
409         let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
410             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
411         let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
412             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
413
414         quote! {
415             fn enum_into_u32(self) -> u32 {
416                 self as u32
417             }
418             unsafe fn enum_from_u32(val: u32) -> Self {
419                 // We put these in const fields so the branches they guard aren't generated even
420                 // on -O0
421                 #(const #const_field: bool =
422                     #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
423                 match val {
424                     // Every valid variant value has an explicit branch. If they get optimized out,
425                     // great. If the representation has changed somehow, and they don't, oh well,
426                     // there's still no UB.
427                     #(#variant_value => #name::#variant_name,)*
428                     // Helps hint to the LLVM that this is a transmute. Note that this branch is
429                     // still unreachable.
430                     #(x if #const_field => {
431                         let x = x as #int_type;
432                         *(&x as *const _ as *const #name)
433                     })*
434                     // Default case. Sometimes causes LLVM to generate a table instead of a simple
435                     // transmute, but, oh well.
436                     _ => #core::hint::unreachable_unchecked(),
437                 }
438             }
439         }
440     };
441
442     let eq_impl = if is_uninhabited {
443         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
444     } else {
445         quote!((*self as u32) == (*other as u32))
446     };
447
448     // used in the enum_set! macro `const fn`s.
449     let self_as_repr_mask = if is_uninhabited {
450         quote! { 0 } // impossible anyway
451     } else {
452         quote! { 1 << self as #repr }
453     };
454
455     let super_impls = if info.no_super_impls {
456         quote! {}
457     } else {
458         quote! {
459             impl #core::cmp::PartialEq for #name {
460                 fn eq(&self, other: &Self) -> bool {
461                     #eq_impl
462                 }
463             }
464             impl #core::cmp::Eq for #name { }
465             #[allow(clippy::expl_impl_clone_on_copy)]
466             impl #core::clone::Clone for #name {
467                 fn clone(&self) -> Self {
468                     *self
469                 }
470             }
471             impl #core::marker::Copy for #name { }
472         }
473     };
474
475     quote! {
476         unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
477             type Repr = #repr;
478             const ALL_BITS: Self::Repr = #all_variants;
479             #into_impl
480             #serde_ops
481         }
482
483         unsafe impl #enumset::EnumSetType for #name { }
484
485         #super_impls
486
487         impl #name {
488             /// Creates a new enumset with only this variant.
489             #[deprecated(note = "This method is an internal implementation detail generated by \
490                                  the `enumset` crate's procedural macro. It should not be used \
491                                  directly. Use `EnumSet::only` instead.")]
492             #[doc(hidden)]
493             pub const fn __impl_enumset_internal__const_only(self) -> #enumset::EnumSet<#name> {
494                 #enumset::EnumSet { __priv_repr: #self_as_repr_mask }
495             }
496
497             /// Creates a new enumset with this variant added.
498             #[deprecated(note = "This method is an internal implementation detail generated by \
499                                  the `enumset` crate's procedural macro. It should not be used \
500                                  directly. Use the `|` operator instead.")]
501             #[doc(hidden)]
502             pub const fn __impl_enumset_internal__const_merge(
503                 self, chain: #enumset::EnumSet<#name>,
504             ) -> #enumset::EnumSet<#name> {
505                 #enumset::EnumSet { __priv_repr: chain.__priv_repr | #self_as_repr_mask }
506             }
507         }
508
509         #ops
510     }
511 }
512
513 #[proc_macro_derive(EnumSetType, attributes(enumset))]
514 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
515     let input: DeriveInput = parse_macro_input!(input);
516     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
517         Ok(attrs) => attrs,
518         Err(e) => return e.write_errors().into(),
519     };
520     match derive_enum_set_type_0(input, attrs) {
521         Ok(v) => v,
522         Err(e) => e.to_compile_error().into(),
523     }
524 }
525 fn derive_enum_set_type_0(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
526     if !input.generics.params.is_empty() {
527         error(
528             input.generics.span(),
529             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
530         )
531     } else if let Data::Enum(data) = &input.data {
532         let mut info = EnumSetInfo::new(&input, attrs);
533         for attr in &input.attrs {
534             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
535                 let meta: Ident = attr.parse_args()?;
536                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
537             }
538         }
539         for variant in &data.variants {
540             info.push_variant(variant)?;
541         }
542         info.validate()?;
543         Ok(enum_set_type_impl(info).into())
544     } else {
545         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
546     }
547 }