]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
83ac1381cf89f0d44331bc42689bb53ba5a6a82c
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit="256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{TokenStream as SynTokenStream, Literal};
8 use std::collections::HashSet;
9 use syn::{*, Result, Error};
10 use syn::export::Span;
11 use syn::spanned::Spanned;
12 use quote::*;
13
14 /// Helper function for emitting compile errors.
15 fn error<T>(span: Span, message: &str) -> Result<T> {
16     Err(Error::new(span, message))
17 }
18
19 /// Decodes the custom attributes for our custom derive.
20 #[derive(FromDeriveInput, Default)]
21 #[darling(attributes(enumset), default)]
22 struct EnumsetAttrs {
23     no_ops: bool,
24     serialize_as_list: bool,
25     serialize_deny_unknown: bool,
26     #[darling(default)]
27     serialize_repr: Option<String>,
28     #[darling(default)]
29     crate_name: Option<String>,
30 }
31
32 /// An variant in the enum set type.
33 struct EnumSetValue {
34     /// The name of the variant.
35     name: Ident,
36     /// The discriminant of the variant.
37     variant_repr: u32,
38 }
39
40 /// Stores information about the enum set type.
41 #[allow(dead_code)]
42 struct EnumSetInfo {
43     /// The name of the enum.
44     name: Ident,
45     /// The crate name to use.
46     crate_name: Option<Ident>,
47     /// The numeric type to serialize the enum as.
48     explicit_serde_repr: Option<Ident>,
49     /// Whether the underlying repr of the enum supports negative values.
50     has_signed_repr: bool,
51     /// Whether the underlying repr of the enum supports values higher than 2^32.
52     has_large_repr: bool,
53     /// A list of variants in the enum.
54     variants: Vec<EnumSetValue>,
55
56     /// The highest encountered variant discriminant.
57     max_discrim: u32,
58     /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
59     cur_discrim: u32,
60     /// A list of variant names that are already in use.
61     used_variant_names: HashSet<String>,
62     /// A list of variant discriminants that are already in use.
63     used_discriminants: HashSet<u32>,
64
65     /// Avoid generating operator overloads on the enum type.
66     no_ops: bool,
67     /// Serialize the enum as a list.
68     serialize_as_list: bool,
69     /// Disallow unknown bits while deserializing the enum.
70     serialize_deny_unknown: bool,
71 }
72 impl EnumSetInfo {
73     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
74         EnumSetInfo {
75             name: input.ident.clone(),
76             crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
77             explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
78             has_signed_repr: false,
79             has_large_repr: false,
80             variants: Vec::new(),
81             max_discrim: 0,
82             cur_discrim: 0,
83             used_variant_names: HashSet::new(),
84             used_discriminants: HashSet::new(),
85             no_ops: attrs.no_ops,
86             serialize_as_list: attrs.serialize_as_list,
87             serialize_deny_unknown: attrs.serialize_deny_unknown
88         }
89     }
90
91     /// Sets an explicit repr for the enumset.
92     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
93         // Check whether the repr is supported, and if so, set some flags for better error
94         // messages later on.
95         match repr {
96             "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
97             "usize" | "u64" | "u128" => {
98                 self.has_large_repr = true;
99                 Ok(())
100             }
101             "i8" | "i16" | "i32" => {
102                 self.has_signed_repr = true;
103                 Ok(())
104             }
105             "isize" | "i64" | "i128" => {
106                 self.has_signed_repr = true;
107                 self.has_large_repr = true;
108                 Ok(())
109             }
110             _ => error(attr_span, "Unsupported repr.")
111         }
112     }
113     /// Adds a variant to the enumset.
114     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
115         if self.used_variant_names.contains(&variant.ident.to_string()) {
116             error(variant.span(), "Duplicated variant name.")
117         } else if let Fields::Unit = variant.fields {
118             // Parse the discriminant.
119             if let Some((_, expr)) = &variant.discriminant {
120                 let discriminant_fail_message = format!(
121                     "Enum set discriminants must be `u32`s.{}",
122                     if self.has_signed_repr || self.has_large_repr {
123                         format!(
124                             " ({} discrimiants are still unsupported with reprs that allow them.)",
125                             if self.has_large_repr {
126                                 "larger"
127                             } else if self.has_signed_repr {
128                                 "negative"
129                             } else {
130                                 "larger or negative"
131                             }
132                         )
133                     } else {
134                         String::new()
135                     },
136                 );
137                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
138                     match i.base10_parse() {
139                         Ok(val) => self.cur_discrim = val,
140                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
141                     }
142                 } else {
143                     error(variant.span(), &discriminant_fail_message)?;
144                 }
145             }
146
147             // Validate the discriminant.
148             let discriminant = self.cur_discrim;
149             if discriminant >= 128 {
150                 let message = if self.variants.len() <= 127 {
151                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
152                 } else {
153                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
154                 };
155                 error(variant.span(), message)?;
156             }
157             if self.used_discriminants.contains(&discriminant) {
158                 error(variant.span(), "Duplicated enum discriminant.")?;
159             }
160
161             // Add the variant to the info.
162             self.cur_discrim += 1;
163             if discriminant > self.max_discrim {
164                 self.max_discrim = discriminant;
165             }
166             self.variants.push(EnumSetValue {
167                 name: variant.ident.clone(),
168                 variant_repr: discriminant,
169             });
170             self.used_variant_names.insert(variant.ident.to_string());
171             self.used_discriminants.insert(discriminant);
172
173             Ok(())
174         } else {
175             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
176         }
177     }
178     /// Validate the enumset type.
179     fn validate(&self) -> Result<()> {
180         // Check if all bits of the bitset can fit in the serialization representation.
181         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
182             let is_overflowed = match explicit_serde_repr.to_string().as_str() {
183                 "u8" => self.max_discrim >= 8,
184                 "u16" => self.max_discrim >= 16,
185                 "u32" => self.max_discrim >= 32,
186                 "u64" => self.max_discrim >= 64,
187                 "u128" => self.max_discrim >= 128,
188                 _ => error(
189                     Span::call_site(),
190                     "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
191                 )?,
192             };
193             if is_overflowed {
194                 error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
195             }
196         }
197         Ok(())
198     }
199
200     /// Computes the underlying type used to store the enumset.
201     fn enumset_repr(&self) -> SynTokenStream {
202         if self.max_discrim <= 7 {
203             quote! { u8 }
204         } else if self.max_discrim <= 15 {
205             quote! { u16 }
206         } else if self.max_discrim <= 31 {
207             quote! { u32 }
208         } else if self.max_discrim <= 63 {
209             quote! { u64 }
210         } else if self.max_discrim <= 127 {
211             quote! { u128 }
212         } else {
213             panic!("max_variant > 127?")
214         }
215     }
216     /// Computes the underlying type used to serialize the enumset.
217     #[cfg(feature = "serde")]
218     fn serde_repr(&self) -> SynTokenStream {
219         if let Some(serde_repr) = &self.explicit_serde_repr {
220             quote! { #serde_repr }
221         } else {
222             self.enumset_repr()
223         }
224     }
225
226     /// Returns a bitmask of all variants in the set.
227     fn all_variants(&self) -> u128 {
228         let mut accum = 0u128;
229         for variant in &self.variants {
230             assert!(variant.variant_repr <= 127);
231             accum |= 1u128 << variant.variant_repr as u128;
232         }
233         accum
234     }
235 }
236
237 /// Generates the actual `EnumSetType` impl.
238 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
239     let name = &info.name;
240     let enumset = match &info.crate_name {
241         Some(crate_name) => quote!(::#crate_name),
242         None => quote!(::enumset),
243     };
244     let typed_enumset = quote!(#enumset::EnumSet<#name>);
245     let core = quote!(#enumset::__internal::core_export);
246
247     let repr = info.enumset_repr();
248     let all_variants = Literal::u128_unsuffixed(info.all_variants());
249
250     let ops = if info.no_ops {
251         quote! {}
252     } else {
253         quote! {
254             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
255                 type Output = #typed_enumset;
256                 fn sub(self, other: O) -> Self::Output {
257                     #enumset::EnumSet::only(self) - other.into()
258                 }
259             }
260             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
261                 type Output = #typed_enumset;
262                 fn bitand(self, other: O) -> Self::Output {
263                     #enumset::EnumSet::only(self) & other.into()
264                 }
265             }
266             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
267                 type Output = #typed_enumset;
268                 fn bitor(self, other: O) -> Self::Output {
269                     #enumset::EnumSet::only(self) | other.into()
270                 }
271             }
272             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
273                 type Output = #typed_enumset;
274                 fn bitxor(self, other: O) -> Self::Output {
275                     #enumset::EnumSet::only(self) ^ other.into()
276                 }
277             }
278             impl #core::ops::Not for #name {
279                 type Output = #typed_enumset;
280                 fn not(self) -> Self::Output {
281                     !#enumset::EnumSet::only(self)
282                 }
283             }
284             impl #core::cmp::PartialEq<#typed_enumset> for #name {
285                 fn eq(&self, other: &#typed_enumset) -> bool {
286                     #enumset::EnumSet::only(*self) == *other
287                 }
288             }
289         }
290     };
291
292
293     #[cfg(feature = "serde")]
294     let serde = quote!(#enumset::__internal::serde);
295
296     #[cfg(feature = "serde")]
297     let serde_ops = if info.serialize_as_list {
298         let expecting_str = format!("a list of {}", name);
299         quote! {
300             fn serialize<S: #serde::Serializer>(
301                 set: #enumset::EnumSet<#name>, ser: S,
302             ) -> #core::result::Result<S::Ok, S::Error> {
303                 use #serde::ser::SerializeSeq;
304                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
305                 for bit in set {
306                     seq.serialize_element(&bit)?;
307                 }
308                 seq.end()
309             }
310             fn deserialize<'de, D: #serde::Deserializer<'de>>(
311                 de: D,
312             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
313                 struct Visitor;
314                 impl <'de> #serde::de::Visitor<'de> for Visitor {
315                     type Value = #enumset::EnumSet<#name>;
316                     fn expecting(
317                         &self, formatter: &mut #core::fmt::Formatter,
318                     ) -> #core::fmt::Result {
319                         write!(formatter, #expecting_str)
320                     }
321                     fn visit_seq<A>(
322                         mut self, mut seq: A,
323                     ) -> #core::result::Result<Self::Value, A::Error> where
324                         A: #serde::de::SeqAccess<'de>
325                     {
326                         let mut accum = #enumset::EnumSet::<#name>::new();
327                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
328                             accum |= val;
329                         }
330                         #core::prelude::v1::Ok(accum)
331                     }
332                 }
333                 de.deserialize_seq(Visitor)
334             }
335         }
336     } else {
337         let serialize_repr = info.serde_repr();
338         let check_unknown = if info.serialize_deny_unknown {
339             quote! {
340                 if value & !#all_variants != 0 {
341                     use #serde::de::Error;
342                     return #core::prelude::v1::Err(
343                         D::Error::custom("enumset contains unknown bits")
344                     )
345                 }
346             }
347         } else {
348             quote! { }
349         };
350         quote! {
351             fn serialize<S: #serde::Serializer>(
352                 set: #enumset::EnumSet<#name>, ser: S,
353             ) -> #core::result::Result<S::Ok, S::Error> {
354                 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
355             }
356             fn deserialize<'de, D: #serde::Deserializer<'de>>(
357                 de: D,
358             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
359                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
360                 #check_unknown
361                 #core::prelude::v1::Ok(#enumset::EnumSet {
362                     __enumset_underlying: (value & #all_variants) as #repr,
363                 })
364             }
365         }
366     };
367
368     #[cfg(not(feature = "serde"))]
369     let serde_ops = quote! { };
370
371     let is_uninhabited = info.variants.is_empty();
372     let is_zst = info.variants.len() == 1;
373     let into_impl = if is_uninhabited {
374         quote! {
375             fn enum_into_u32(self) -> u32 {
376                 panic!(concat!(stringify!(#name), " is uninhabited."))
377             }
378             unsafe fn enum_from_u32(val: u32) -> Self {
379                 panic!(concat!(stringify!(#name), " is uninhabited."))
380             }
381         }
382     } else if is_zst {
383         let variant = &info.variants[0].name;
384         quote! {
385             fn enum_into_u32(self) -> u32 {
386                 self as u32
387             }
388             unsafe fn enum_from_u32(val: u32) -> Self {
389                 #name::#variant
390             }
391         }
392     } else {
393         let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
394         let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
395
396         let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
397             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
398         let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
399             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
400
401         quote! {
402             fn enum_into_u32(self) -> u32 {
403                 self as u32
404             }
405             unsafe fn enum_from_u32(val: u32) -> Self {
406                 // We put these in const fields so the branches they guard aren't generated even
407                 // on -O0
408                 #(const #const_field: bool =
409                     #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
410                 match val {
411                     // Every valid variant value has an explicit branch. If they get optimized out,
412                     // great. If the representation has changed somehow, and they don't, oh well,
413                     // there's still no UB.
414                     #(#variant_value => #name::#variant_name,)*
415                     // Helps hint to the LLVM that this is a transmute. Note that this branch is
416                     // still unreachable.
417                     #(x if #const_field => {
418                         let x = x as #int_type;
419                         *(&x as *const _ as *const #name)
420                     })*
421                     // Default case. Sometimes causes LLVM to generate a table instead of a simple
422                     // transmute, but, oh well.
423                     _ => #core::hint::unreachable_unchecked(),
424                 }
425             }
426         }
427     };
428
429     let eq_impl = if is_uninhabited {
430         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
431     } else {
432         quote!((*self as u32) == (*other as u32))
433     };
434
435     quote! {
436         unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
437             type Repr = #repr;
438             const ALL_BITS: Self::Repr = #all_variants;
439             #into_impl
440             #serde_ops
441         }
442
443         unsafe impl #enumset::EnumSetType for #name { }
444
445         impl #core::cmp::PartialEq for #name {
446             fn eq(&self, other: &Self) -> bool {
447                 #eq_impl
448             }
449         }
450         impl #core::cmp::Eq for #name { }
451         impl #core::clone::Clone for #name {
452             fn clone(&self) -> Self {
453                 *self
454             }
455         }
456         impl #core::marker::Copy for #name { }
457
458         #ops
459     }
460 }
461
462 /// A wrapper that parses the input enum.
463 #[proc_macro_derive(EnumSetType, attributes(enumset))]
464 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
465     let input: DeriveInput = parse_macro_input!(input);
466     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
467         Ok(attrs) => attrs,
468         Err(e) => return e.write_errors().into(),
469     };
470     match derive_enum_set_type_0(input, attrs) {
471         Ok(v) => v,
472         Err(e) => e.to_compile_error().into(),
473     }
474 }
475 fn derive_enum_set_type_0(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
476     if !input.generics.params.is_empty() {
477         error(
478             input.generics.span(),
479             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
480         )
481     } else if let Data::Enum(data) = &input.data {
482         let mut info = EnumSetInfo::new(&input, attrs);
483         for attr in &input.attrs {
484             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
485                 let meta: Ident = attr.parse_args()?;
486                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
487             }
488         }
489         for variant in &data.variants {
490             info.push_variant(variant)?;
491         }
492         info.validate()?;
493         Ok(enum_set_type_impl(info).into())
494     } else {
495         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
496     }
497 }