]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
Add serialize_as_map
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit = "256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{Literal, Span, TokenStream as SynTokenStream};
8 use quote::*;
9 use std::{collections::HashSet, fmt::Display};
10 use syn::spanned::Spanned;
11 use syn::{Error, Result, *};
12
13 /// Helper function for emitting compile errors.
14 fn error<T>(span: Span, message: impl Display) -> Result<T> {
15     Err(Error::new(span, message))
16 }
17
18 /// Decodes the custom attributes for our custom derive.
19 #[derive(FromDeriveInput, Default)]
20 #[darling(attributes(enumset), default)]
21 struct EnumsetAttrs {
22     no_ops: bool,
23     no_super_impls: bool,
24     #[darling(default)]
25     repr: Option<String>,
26     serialize_as_list: bool,
27     serialize_as_map: bool,
28     serialize_deny_unknown: bool,
29     #[darling(default)]
30     serialize_repr: Option<String>,
31     #[darling(default)]
32     crate_name: Option<String>,
33 }
34
35 /// An variant in the enum set type.
36 struct EnumSetValue {
37     /// The name of the variant.
38     name: Ident,
39     /// The discriminant of the variant.
40     variant_repr: u32,
41 }
42
43 /// Stores information about the enum set type.
44 #[allow(dead_code)]
45 struct EnumSetInfo {
46     /// The name of the enum.
47     name: Ident,
48     /// The crate name to use.
49     crate_name: Option<Ident>,
50     /// The numeric type to represent the `EnumSet` as in memory.
51     explicit_mem_repr: Option<Ident>,
52     /// The numeric type to serialize the enum as.
53     explicit_serde_repr: Option<Ident>,
54     /// Whether the underlying repr of the enum supports negative values.
55     has_signed_repr: bool,
56     /// Whether the underlying repr of the enum supports values higher than 2^32.
57     has_large_repr: bool,
58     /// A list of variants in the enum.
59     variants: Vec<EnumSetValue>,
60
61     /// The highest encountered variant discriminant.
62     max_discrim: u32,
63     /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
64     cur_discrim: u32,
65     /// A list of variant names that are already in use.
66     used_variant_names: HashSet<String>,
67     /// A list of variant discriminants that are already in use.
68     used_discriminants: HashSet<u32>,
69
70     /// Avoid generating operator overloads on the enum type.
71     no_ops: bool,
72     /// Avoid generating implementations for `Clone`, `Copy`, `Eq`, and `PartialEq`.
73     no_super_impls: bool,
74     /// Serialize the enum as a list.
75     serialize_as_list: bool,
76     /// Serialize the enum as a map.
77     serialize_as_map: bool,
78     /// Disallow unknown bits while deserializing the enum.
79     serialize_deny_unknown: bool,
80 }
81 impl EnumSetInfo {
82     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
83         EnumSetInfo {
84             name: input.ident.clone(),
85             crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
86             explicit_mem_repr: attrs.repr.map(|x| Ident::new(&x, Span::call_site())),
87             explicit_serde_repr: attrs
88                 .serialize_repr
89                 .map(|x| Ident::new(&x, Span::call_site())),
90             has_signed_repr: false,
91             has_large_repr: false,
92             variants: Vec::new(),
93             max_discrim: 0,
94             cur_discrim: 0,
95             used_variant_names: HashSet::new(),
96             used_discriminants: HashSet::new(),
97             no_ops: attrs.no_ops,
98             no_super_impls: attrs.no_super_impls,
99             serialize_as_list: attrs.serialize_as_list,
100             serialize_as_map: attrs.serialize_as_map,
101             serialize_deny_unknown: attrs.serialize_deny_unknown,
102         }
103     }
104
105     /// Sets an explicit repr for the enumset.
106     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
107         // Check whether the repr is supported, and if so, set some flags for better error
108         // messages later on.
109         match repr {
110             "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
111             "usize" | "u64" | "u128" => {
112                 self.has_large_repr = true;
113                 Ok(())
114             }
115             "i8" | "i16" | "i32" => {
116                 self.has_signed_repr = true;
117                 Ok(())
118             }
119             "isize" | "i64" | "i128" => {
120                 self.has_signed_repr = true;
121                 self.has_large_repr = true;
122                 Ok(())
123             }
124             _ => error(attr_span, "Unsupported repr."),
125         }
126     }
127     /// Adds a variant to the enumset.
128     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
129         if self.used_variant_names.contains(&variant.ident.to_string()) {
130             error(variant.span(), "Duplicated variant name.")
131         } else if let Fields::Unit = variant.fields {
132             // Parse the discriminant.
133             if let Some((_, expr)) = &variant.discriminant {
134                 let discriminant_fail_message = format!(
135                     "Enum set discriminants must be `u32`s.{}",
136                     if self.has_signed_repr || self.has_large_repr {
137                         format!(
138                             " ({} discrimiants are still unsupported with reprs that allow them.)",
139                             if self.has_large_repr {
140                                 "larger"
141                             } else if self.has_signed_repr {
142                                 "negative"
143                             } else {
144                                 "larger or negative"
145                             }
146                         )
147                     } else {
148                         String::new()
149                     },
150                 );
151                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
152                     match i.base10_parse() {
153                         Ok(val) => self.cur_discrim = val,
154                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
155                     }
156                 } else {
157                     error(variant.span(), &discriminant_fail_message)?;
158                 }
159             }
160
161             // Validate the discriminant.
162             let discriminant = self.cur_discrim;
163             if discriminant >= 128 {
164                 let message = if self.variants.len() <= 127 {
165                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
166                 } else {
167                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
168                 };
169                 error(variant.span(), message)?;
170             }
171             if self.used_discriminants.contains(&discriminant) {
172                 error(variant.span(), "Duplicated enum discriminant.")?;
173             }
174
175             // Add the variant to the info.
176             self.cur_discrim += 1;
177             if discriminant > self.max_discrim {
178                 self.max_discrim = discriminant;
179             }
180             self.variants
181                 .push(EnumSetValue { name: variant.ident.clone(), variant_repr: discriminant });
182             self.used_variant_names.insert(variant.ident.to_string());
183             self.used_discriminants.insert(discriminant);
184
185             Ok(())
186         } else {
187             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
188         }
189     }
190     /// Validate the enumset type.
191     fn validate(&self) -> Result<()> {
192         fn do_check(ty: &str, max_discrim: u32, what: &str) -> Result<()> {
193             let is_overflowed = match ty {
194                 "u8" => max_discrim >= 8,
195                 "u16" => max_discrim >= 16,
196                 "u32" => max_discrim >= 32,
197                 "u64" => max_discrim >= 64,
198                 "u128" => max_discrim >= 128,
199                 _ => error(
200                     Span::call_site(),
201                     format!(
202                         "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for {}.",
203                         what
204                     ),
205                 )?,
206             };
207             if is_overflowed {
208                 error(Span::call_site(), format!("{} cannot be smaller than bitset.", what))?;
209             }
210             Ok(())
211         }
212
213         // Check if all bits of the bitset can fit in the serialization representation.
214         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
215             do_check(&explicit_serde_repr.to_string(), self.max_discrim, "serialize_repr")?;
216         }
217
218         // Check if all bits of the bitset can fit in the memory representation, if one was given.
219         if let Some(explicit_mem_repr) = &self.explicit_mem_repr {
220             do_check(&explicit_mem_repr.to_string(), self.max_discrim, "repr")?;
221         }
222         Ok(())
223     }
224
225     /// Computes the underlying type used to store the enumset.
226     fn enumset_repr(&self) -> SynTokenStream {
227         if let Some(explicit_mem_repr) = &self.explicit_mem_repr {
228             explicit_mem_repr.to_token_stream()
229         } else if self.max_discrim <= 7 {
230             quote! { u8 }
231         } else if self.max_discrim <= 15 {
232             quote! { u16 }
233         } else if self.max_discrim <= 31 {
234             quote! { u32 }
235         } else if self.max_discrim <= 63 {
236             quote! { u64 }
237         } else if self.max_discrim <= 127 {
238             quote! { u128 }
239         } else {
240             panic!("max_variant > 127?")
241         }
242     }
243     /// Computes the underlying type used to serialize the enumset.
244     #[cfg(feature = "serde")]
245     fn serde_repr(&self) -> SynTokenStream {
246         if let Some(serde_repr) = &self.explicit_serde_repr {
247             quote! { #serde_repr }
248         } else {
249             self.enumset_repr()
250         }
251     }
252
253     /// Returns a bitmask of all variants in the set.
254     fn all_variants(&self) -> u128 {
255         let mut accum = 0u128;
256         for variant in &self.variants {
257             assert!(variant.variant_repr <= 127);
258             accum |= 1u128 << variant.variant_repr as u128;
259         }
260         accum
261     }
262 }
263
264 /// Generates the actual `EnumSetType` impl.
265 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
266     let name = &info.name;
267
268     let enumset = match &info.crate_name {
269         Some(crate_name) => quote!(::#crate_name),
270         None => {
271             #[cfg(feature = "proc-macro-crate")]
272             {
273                 use proc_macro_crate::FoundCrate;
274
275                 let crate_name = proc_macro_crate::crate_name("enumset");
276                 match crate_name {
277                     Ok(FoundCrate::Name(name)) => {
278                         let ident = Ident::new(&name, Span::call_site());
279                         quote!(::#ident)
280                     }
281                     _ => quote!(::enumset),
282                 }
283             }
284
285             #[cfg(not(feature = "proc-macro-crate"))]
286             {
287                 quote!(::enumset)
288             }
289         }
290     };
291     let typed_enumset = quote!(#enumset::EnumSet<#name>);
292     let core = quote!(#enumset::__internal::core_export);
293
294     let repr = info.enumset_repr();
295     let all_variants = Literal::u128_unsuffixed(info.all_variants());
296
297     let ops = if info.no_ops {
298         quote! {}
299     } else {
300         quote! {
301             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
302                 type Output = #typed_enumset;
303                 fn sub(self, other: O) -> Self::Output {
304                     #enumset::EnumSet::only(self) - other.into()
305                 }
306             }
307             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
308                 type Output = #typed_enumset;
309                 fn bitand(self, other: O) -> Self::Output {
310                     #enumset::EnumSet::only(self) & other.into()
311                 }
312             }
313             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
314                 type Output = #typed_enumset;
315                 fn bitor(self, other: O) -> Self::Output {
316                     #enumset::EnumSet::only(self) | other.into()
317                 }
318             }
319             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
320                 type Output = #typed_enumset;
321                 fn bitxor(self, other: O) -> Self::Output {
322                     #enumset::EnumSet::only(self) ^ other.into()
323                 }
324             }
325             impl #core::ops::Not for #name {
326                 type Output = #typed_enumset;
327                 fn not(self) -> Self::Output {
328                     !#enumset::EnumSet::only(self)
329                 }
330             }
331             impl #core::cmp::PartialEq<#typed_enumset> for #name {
332                 fn eq(&self, other: &#typed_enumset) -> bool {
333                     #enumset::EnumSet::only(*self) == *other
334                 }
335             }
336         }
337     };
338
339     #[cfg(feature = "serde")]
340     let serde = quote!(#enumset::__internal::serde);
341
342     #[cfg(feature = "serde")]
343     let serde_ops = if info.serialize_as_list {
344         let expecting_str = format!("a list of {}", name);
345         quote! {
346             fn serialize<S: #serde::Serializer>(
347                 set: #enumset::EnumSet<#name>, ser: S,
348             ) -> #core::result::Result<S::Ok, S::Error> {
349                 use #serde::ser::SerializeSeq;
350                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
351                 for bit in set {
352                     seq.serialize_element(&bit)?;
353                 }
354                 seq.end()
355             }
356             fn deserialize<'de, D: #serde::Deserializer<'de>>(
357                 de: D,
358             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
359                 struct Visitor;
360                 impl <'de> #serde::de::Visitor<'de> for Visitor {
361                     type Value = #enumset::EnumSet<#name>;
362                     fn expecting(
363                         &self, formatter: &mut #core::fmt::Formatter,
364                     ) -> #core::fmt::Result {
365                         write!(formatter, #expecting_str)
366                     }
367                     fn visit_seq<A>(
368                         mut self, mut seq: A,
369                     ) -> #core::result::Result<Self::Value, A::Error> where
370                         A: #serde::de::SeqAccess<'de>
371                     {
372                         let mut accum = #enumset::EnumSet::<#name>::new();
373                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
374                             accum |= val;
375                         }
376                         #core::prelude::v1::Ok(accum)
377                     }
378                 }
379                 de.deserialize_seq(Visitor)
380             }
381         }
382     } else if info.serialize_as_map {
383         let expecting_str = format!("a map from {} to bool", name);
384         quote! {
385             fn serialize<S: #serde::Serializer>(
386                 set: #enumset::EnumSet<#name>, ser: S,
387             ) -> #core::result::Result<S::Ok, S::Error> {
388                 use #serde::ser::SerializeMap;
389                 let mut map = ser.serialize_map(#core::prelude::v1::Some(set.len()))?;
390                 for bit in set {
391                     map.serialize_entry(&bit, &true)?;
392                 }
393                 map.end()
394             }
395             fn deserialize<'de, D: #serde::Deserializer<'de>>(
396                 de: D,
397             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
398                 struct Visitor;
399                 impl <'de> #serde::de::Visitor<'de> for Visitor {
400                     type Value = #enumset::EnumSet<#name>;
401                     fn expecting(
402                         &self, formatter: &mut #core::fmt::Formatter,
403                     ) -> #core::fmt::Result {
404                         write!(formatter, #expecting_str)
405                     }
406                     fn visit_map<A>(
407                         mut self, mut map: A,
408                     ) -> #core::result::Result<Self::Value, A::Error> where
409                         A: #serde::de::MapAccess<'de>
410                     {
411                         let mut accum = #enumset::EnumSet::<#name>::new();
412                         while let #core::prelude::v1::Some((val, true)) = map.next_entry::<#name, bool>()? {
413                             accum |= val;
414                         }
415                         #core::prelude::v1::Ok(accum)
416                     }
417                 }
418                 de.deserialize_map(Visitor)
419             }
420         }
421     } else {
422         let serialize_repr = info.serde_repr();
423         let check_unknown = if info.serialize_deny_unknown {
424             quote! {
425                 if value & !#all_variants != 0 {
426                     use #serde::de::Error;
427                     return #core::prelude::v1::Err(
428                         D::Error::custom("enumset contains unknown bits")
429                     )
430                 }
431             }
432         } else {
433             quote! {}
434         };
435         quote! {
436             fn serialize<S: #serde::Serializer>(
437                 set: #enumset::EnumSet<#name>, ser: S,
438             ) -> #core::result::Result<S::Ok, S::Error> {
439                 #serde::Serialize::serialize(&(set.__priv_repr as #serialize_repr), ser)
440             }
441             fn deserialize<'de, D: #serde::Deserializer<'de>>(
442                 de: D,
443             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
444                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
445                 #check_unknown
446                 #core::prelude::v1::Ok(#enumset::EnumSet {
447                     __priv_repr: (value & #all_variants) as #repr,
448                 })
449             }
450         }
451     };
452
453     #[cfg(not(feature = "serde"))]
454     let serde_ops = quote! {};
455
456     let is_uninhabited = info.variants.is_empty();
457     let is_zst = info.variants.len() == 1;
458     let into_impl = if is_uninhabited {
459         quote! {
460             fn enum_into_u32(self) -> u32 {
461                 panic!(concat!(stringify!(#name), " is uninhabited."))
462             }
463             unsafe fn enum_from_u32(val: u32) -> Self {
464                 panic!(concat!(stringify!(#name), " is uninhabited."))
465             }
466         }
467     } else if is_zst {
468         let variant = &info.variants[0].name;
469         quote! {
470             fn enum_into_u32(self) -> u32 {
471                 self as u32
472             }
473             unsafe fn enum_from_u32(val: u32) -> Self {
474                 #name::#variant
475             }
476         }
477     } else {
478         let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
479         let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
480
481         let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
482             .iter()
483             .map(|x| Ident::new(x, Span::call_site()))
484             .collect();
485         let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
486             .iter()
487             .map(|x| Ident::new(x, Span::call_site()))
488             .collect();
489
490         quote! {
491             fn enum_into_u32(self) -> u32 {
492                 self as u32
493             }
494             unsafe fn enum_from_u32(val: u32) -> Self {
495                 // We put these in const fields so the branches they guard aren't generated even
496                 // on -O0
497                 #(const #const_field: bool =
498                     #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
499                 match val {
500                     // Every valid variant value has an explicit branch. If they get optimized out,
501                     // great. If the representation has changed somehow, and they don't, oh well,
502                     // there's still no UB.
503                     #(#variant_value => #name::#variant_name,)*
504                     // Helps hint to the LLVM that this is a transmute. Note that this branch is
505                     // still unreachable.
506                     #(x if #const_field => {
507                         let x = x as #int_type;
508                         *(&x as *const _ as *const #name)
509                     })*
510                     // Default case. Sometimes causes LLVM to generate a table instead of a simple
511                     // transmute, but, oh well.
512                     _ => #core::hint::unreachable_unchecked(),
513                 }
514             }
515         }
516     };
517
518     let eq_impl = if is_uninhabited {
519         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
520     } else {
521         quote!((*self as u32) == (*other as u32))
522     };
523
524     // used in the enum_set! macro `const fn`s.
525     let self_as_repr_mask = if is_uninhabited {
526         quote! { 0 } // impossible anyway
527     } else {
528         quote! { 1 << self as #repr }
529     };
530
531     let super_impls = if info.no_super_impls {
532         quote! {}
533     } else {
534         quote! {
535             impl #core::cmp::PartialEq for #name {
536                 fn eq(&self, other: &Self) -> bool {
537                     #eq_impl
538                 }
539             }
540             impl #core::cmp::Eq for #name { }
541             #[allow(clippy::expl_impl_clone_on_copy)]
542             impl #core::clone::Clone for #name {
543                 fn clone(&self) -> Self {
544                     *self
545                 }
546             }
547             impl #core::marker::Copy for #name { }
548         }
549     };
550
551     let impl_with_repr = if info.explicit_mem_repr.is_some() {
552         quote! {
553             unsafe impl #enumset::EnumSetTypeWithRepr for #name {
554                 type Repr = #repr;
555             }
556         }
557     } else {
558         quote! {}
559     };
560
561     quote! {
562         unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
563             type Repr = #repr;
564             const ALL_BITS: Self::Repr = #all_variants;
565             #into_impl
566             #serde_ops
567         }
568
569         unsafe impl #enumset::EnumSetType for #name { }
570
571         #impl_with_repr
572         #super_impls
573
574         impl #name {
575             /// Creates a new enumset with only this variant.
576             #[deprecated(note = "This method is an internal implementation detail generated by \
577                                  the `enumset` crate's procedural macro. It should not be used \
578                                  directly. Use `EnumSet::only` instead.")]
579             #[doc(hidden)]
580             pub const fn __impl_enumset_internal__const_only(self) -> #enumset::EnumSet<#name> {
581                 #enumset::EnumSet { __priv_repr: #self_as_repr_mask }
582             }
583
584             /// Creates a new enumset with this variant added.
585             #[deprecated(note = "This method is an internal implementation detail generated by \
586                                  the `enumset` crate's procedural macro. It should not be used \
587                                  directly. Use the `|` operator instead.")]
588             #[doc(hidden)]
589             pub const fn __impl_enumset_internal__const_merge(
590                 self, chain: #enumset::EnumSet<#name>,
591             ) -> #enumset::EnumSet<#name> {
592                 #enumset::EnumSet { __priv_repr: chain.__priv_repr | #self_as_repr_mask }
593             }
594         }
595
596         #ops
597     }
598 }
599
600 #[proc_macro_derive(EnumSetType, attributes(enumset))]
601 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
602     let input: DeriveInput = parse_macro_input!(input);
603     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
604         Ok(attrs) => attrs,
605         Err(e) => return e.write_errors().into(),
606     };
607     match derive_enum_set_type_0(input, attrs) {
608         Ok(v) => v,
609         Err(e) => e.to_compile_error().into(),
610     }
611 }
612 fn derive_enum_set_type_0(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
613     if !input.generics.params.is_empty() {
614         error(
615             input.generics.span(),
616             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
617         )
618     } else if let Data::Enum(data) = &input.data {
619         let mut info = EnumSetInfo::new(&input, attrs);
620         for attr in &input.attrs {
621             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
622                 let meta: Ident = attr.parse_args()?;
623                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
624             }
625         }
626         for variant in &data.variants {
627             info.push_variant(variant)?;
628         }
629         info.validate()?;
630         Ok(enum_set_type_impl(info).into())
631     } else {
632         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
633     }
634 }