]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
cef5a39bdb196ac2a05e566c3e5e42da50732270
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit = "256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{Literal, Span, TokenStream as SynTokenStream};
8 use quote::*;
9 use std::{collections::HashSet, fmt::Display};
10 use syn::spanned::Spanned;
11 use syn::{Error, Result, *};
12
13 /// Helper function for emitting compile errors.
14 fn error<T>(span: Span, message: impl Display) -> Result<T> {
15     Err(Error::new(span, message))
16 }
17
18 /// Decodes the custom attributes for our custom derive.
19 #[derive(FromDeriveInput, Default)]
20 #[darling(attributes(enumset), default)]
21 struct EnumsetAttrs {
22     no_ops: bool,
23     no_super_impls: bool,
24     #[darling(default)]
25     repr: Option<String>,
26     serialize_as_list: bool,
27     serialize_deny_unknown: bool,
28     #[darling(default)]
29     serialize_repr: Option<String>,
30     #[darling(default)]
31     crate_name: Option<String>,
32 }
33
34 /// An variant in the enum set type.
35 struct EnumSetValue {
36     /// The name of the variant.
37     name: Ident,
38     /// The discriminant of the variant.
39     variant_repr: u32,
40 }
41
42 /// Stores information about the enum set type.
43 #[allow(dead_code)]
44 struct EnumSetInfo {
45     /// The name of the enum.
46     name: Ident,
47     /// The crate name to use.
48     crate_name: Option<Ident>,
49     /// The numeric type to represent the `EnumSet` as in memory.
50     explicit_mem_repr: Option<Ident>,
51     /// The numeric type to serialize the enum as.
52     explicit_serde_repr: Option<Ident>,
53     /// Whether the underlying repr of the enum supports negative values.
54     has_signed_repr: bool,
55     /// Whether the underlying repr of the enum supports values higher than 2^32.
56     has_large_repr: bool,
57     /// A list of variants in the enum.
58     variants: Vec<EnumSetValue>,
59
60     /// The highest encountered variant discriminant.
61     max_discrim: u32,
62     /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
63     cur_discrim: u32,
64     /// A list of variant names that are already in use.
65     used_variant_names: HashSet<String>,
66     /// A list of variant discriminants that are already in use.
67     used_discriminants: HashSet<u32>,
68
69     /// Avoid generating operator overloads on the enum type.
70     no_ops: bool,
71     /// Avoid generating implementations for `Clone`, `Copy`, `Eq`, and `PartialEq`.
72     no_super_impls: bool,
73     /// Serialize the enum as a list.
74     serialize_as_list: bool,
75     /// Disallow unknown bits while deserializing the enum.
76     serialize_deny_unknown: bool,
77 }
78 impl EnumSetInfo {
79     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
80         EnumSetInfo {
81             name: input.ident.clone(),
82             crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
83             explicit_mem_repr: attrs.repr.map(|x| Ident::new(&x, Span::call_site())),
84             explicit_serde_repr: attrs
85                 .serialize_repr
86                 .map(|x| Ident::new(&x, Span::call_site())),
87             has_signed_repr: false,
88             has_large_repr: false,
89             variants: Vec::new(),
90             max_discrim: 0,
91             cur_discrim: 0,
92             used_variant_names: HashSet::new(),
93             used_discriminants: HashSet::new(),
94             no_ops: attrs.no_ops,
95             no_super_impls: attrs.no_super_impls,
96             serialize_as_list: attrs.serialize_as_list,
97             serialize_deny_unknown: attrs.serialize_deny_unknown,
98         }
99     }
100
101     /// Sets an explicit repr for the enumset.
102     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
103         // Check whether the repr is supported, and if so, set some flags for better error
104         // messages later on.
105         match repr {
106             "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
107             "usize" | "u64" | "u128" => {
108                 self.has_large_repr = true;
109                 Ok(())
110             }
111             "i8" | "i16" | "i32" => {
112                 self.has_signed_repr = true;
113                 Ok(())
114             }
115             "isize" | "i64" | "i128" => {
116                 self.has_signed_repr = true;
117                 self.has_large_repr = true;
118                 Ok(())
119             }
120             _ => error(attr_span, "Unsupported repr."),
121         }
122     }
123     /// Adds a variant to the enumset.
124     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
125         if self.used_variant_names.contains(&variant.ident.to_string()) {
126             error(variant.span(), "Duplicated variant name.")
127         } else if let Fields::Unit = variant.fields {
128             // Parse the discriminant.
129             if let Some((_, expr)) = &variant.discriminant {
130                 let discriminant_fail_message = format!(
131                     "Enum set discriminants must be `u32`s.{}",
132                     if self.has_signed_repr || self.has_large_repr {
133                         format!(
134                             " ({} discrimiants are still unsupported with reprs that allow them.)",
135                             if self.has_large_repr {
136                                 "larger"
137                             } else if self.has_signed_repr {
138                                 "negative"
139                             } else {
140                                 "larger or negative"
141                             }
142                         )
143                     } else {
144                         String::new()
145                     },
146                 );
147                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
148                     match i.base10_parse() {
149                         Ok(val) => self.cur_discrim = val,
150                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
151                     }
152                 } else {
153                     error(variant.span(), &discriminant_fail_message)?;
154                 }
155             }
156
157             // Validate the discriminant.
158             let discriminant = self.cur_discrim;
159             if discriminant >= 128 {
160                 let message = if self.variants.len() <= 127 {
161                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
162                 } else {
163                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
164                 };
165                 error(variant.span(), message)?;
166             }
167             if self.used_discriminants.contains(&discriminant) {
168                 error(variant.span(), "Duplicated enum discriminant.")?;
169             }
170
171             // Add the variant to the info.
172             self.cur_discrim += 1;
173             if discriminant > self.max_discrim {
174                 self.max_discrim = discriminant;
175             }
176             self.variants
177                 .push(EnumSetValue { name: variant.ident.clone(), variant_repr: discriminant });
178             self.used_variant_names.insert(variant.ident.to_string());
179             self.used_discriminants.insert(discriminant);
180
181             Ok(())
182         } else {
183             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
184         }
185     }
186     /// Validate the enumset type.
187     fn validate(&self) -> Result<()> {
188         fn do_check(ty: &str, max_discrim: u32, what: &str) -> Result<()> {
189             let is_overflowed = match ty {
190                 "u8" => max_discrim >= 8,
191                 "u16" => max_discrim >= 16,
192                 "u32" => max_discrim >= 32,
193                 "u64" => max_discrim >= 64,
194                 "u128" => max_discrim >= 128,
195                 _ => error(
196                     Span::call_site(),
197                     format!("Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for {what}."),
198                 )?,
199             };
200             if is_overflowed {
201                 error(Span::call_site(), format!("{what} cannot be smaller than bitset."))?;
202             }
203             Ok(())
204         }
205
206         // Check if all bits of the bitset can fit in the serialization representation.
207         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
208             do_check(&explicit_serde_repr.to_string(), self.max_discrim, "serialize_repr")?;
209         }
210
211         // Check if all bits of the bitset can fit in the memory representation, if one was given.
212         if let Some(explicit_mem_repr) = &self.explicit_mem_repr {
213             do_check(&explicit_mem_repr.to_string(), self.max_discrim, "repr")?;
214         }
215         Ok(())
216     }
217
218     /// Computes the underlying type used to store the enumset.
219     fn enumset_repr(&self) -> SynTokenStream {
220         if let Some(explicit_mem_repr) = &self.explicit_mem_repr {
221             explicit_mem_repr.to_token_stream()
222         } else if self.max_discrim <= 7 {
223             quote! { u8 }
224         } else if self.max_discrim <= 15 {
225             quote! { u16 }
226         } else if self.max_discrim <= 31 {
227             quote! { u32 }
228         } else if self.max_discrim <= 63 {
229             quote! { u64 }
230         } else if self.max_discrim <= 127 {
231             quote! { u128 }
232         } else {
233             panic!("max_variant > 127?")
234         }
235     }
236     /// Computes the underlying type used to serialize the enumset.
237     #[cfg(feature = "serde")]
238     fn serde_repr(&self) -> SynTokenStream {
239         if let Some(serde_repr) = &self.explicit_serde_repr {
240             quote! { #serde_repr }
241         } else {
242             self.enumset_repr()
243         }
244     }
245
246     /// Returns a bitmask of all variants in the set.
247     fn all_variants(&self) -> u128 {
248         let mut accum = 0u128;
249         for variant in &self.variants {
250             assert!(variant.variant_repr <= 127);
251             accum |= 1u128 << variant.variant_repr as u128;
252         }
253         accum
254     }
255 }
256
257 /// Generates the actual `EnumSetType` impl.
258 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
259     let name = &info.name;
260
261     let enumset = match &info.crate_name {
262         Some(crate_name) => quote!(::#crate_name),
263         None => {
264             #[cfg(feature = "proc-macro-crate")]
265             {
266                 use proc_macro_crate::FoundCrate;
267
268                 let crate_name = proc_macro_crate::crate_name("enumset");
269                 match crate_name {
270                     Ok(FoundCrate::Name(name)) => {
271                         let ident = Ident::new(&name, Span::call_site());
272                         quote!(::#ident)
273                     }
274                     _ => quote!(::enumset),
275                 }
276             }
277
278             #[cfg(not(feature = "proc-macro-crate"))]
279             {
280                 quote!(::enumset)
281             }
282         }
283     };
284     let typed_enumset = quote!(#enumset::EnumSet<#name>);
285     let core = quote!(#enumset::__internal::core_export);
286
287     let repr = info.enumset_repr();
288     let all_variants = Literal::u128_unsuffixed(info.all_variants());
289
290     let ops = if info.no_ops {
291         quote! {}
292     } else {
293         quote! {
294             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
295                 type Output = #typed_enumset;
296                 fn sub(self, other: O) -> Self::Output {
297                     #enumset::EnumSet::only(self) - other.into()
298                 }
299             }
300             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
301                 type Output = #typed_enumset;
302                 fn bitand(self, other: O) -> Self::Output {
303                     #enumset::EnumSet::only(self) & other.into()
304                 }
305             }
306             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
307                 type Output = #typed_enumset;
308                 fn bitor(self, other: O) -> Self::Output {
309                     #enumset::EnumSet::only(self) | other.into()
310                 }
311             }
312             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
313                 type Output = #typed_enumset;
314                 fn bitxor(self, other: O) -> Self::Output {
315                     #enumset::EnumSet::only(self) ^ other.into()
316                 }
317             }
318             impl #core::ops::Not for #name {
319                 type Output = #typed_enumset;
320                 fn not(self) -> Self::Output {
321                     !#enumset::EnumSet::only(self)
322                 }
323             }
324             impl #core::cmp::PartialEq<#typed_enumset> for #name {
325                 fn eq(&self, other: &#typed_enumset) -> bool {
326                     #enumset::EnumSet::only(*self) == *other
327                 }
328             }
329         }
330     };
331
332     #[cfg(feature = "serde")]
333     let serde = quote!(#enumset::__internal::serde);
334
335     #[cfg(feature = "serde")]
336     let serde_ops = if info.serialize_as_list {
337         let expecting_str = format!("a list of {}", name);
338         quote! {
339             fn serialize<S: #serde::Serializer>(
340                 set: #enumset::EnumSet<#name>, ser: S,
341             ) -> #core::result::Result<S::Ok, S::Error> {
342                 use #serde::ser::SerializeSeq;
343                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
344                 for bit in set {
345                     seq.serialize_element(&bit)?;
346                 }
347                 seq.end()
348             }
349             fn deserialize<'de, D: #serde::Deserializer<'de>>(
350                 de: D,
351             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
352                 struct Visitor;
353                 impl <'de> #serde::de::Visitor<'de> for Visitor {
354                     type Value = #enumset::EnumSet<#name>;
355                     fn expecting(
356                         &self, formatter: &mut #core::fmt::Formatter,
357                     ) -> #core::fmt::Result {
358                         write!(formatter, #expecting_str)
359                     }
360                     fn visit_seq<A>(
361                         mut self, mut seq: A,
362                     ) -> #core::result::Result<Self::Value, A::Error> where
363                         A: #serde::de::SeqAccess<'de>
364                     {
365                         let mut accum = #enumset::EnumSet::<#name>::new();
366                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
367                             accum |= val;
368                         }
369                         #core::prelude::v1::Ok(accum)
370                     }
371                 }
372                 de.deserialize_seq(Visitor)
373             }
374         }
375     } else {
376         let serialize_repr = info.serde_repr();
377         let check_unknown = if info.serialize_deny_unknown {
378             quote! {
379                 if value & !#all_variants != 0 {
380                     use #serde::de::Error;
381                     return #core::prelude::v1::Err(
382                         D::Error::custom("enumset contains unknown bits")
383                     )
384                 }
385             }
386         } else {
387             quote! {}
388         };
389         quote! {
390             fn serialize<S: #serde::Serializer>(
391                 set: #enumset::EnumSet<#name>, ser: S,
392             ) -> #core::result::Result<S::Ok, S::Error> {
393                 #serde::Serialize::serialize(&(set.__priv_repr as #serialize_repr), ser)
394             }
395             fn deserialize<'de, D: #serde::Deserializer<'de>>(
396                 de: D,
397             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
398                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
399                 #check_unknown
400                 #core::prelude::v1::Ok(#enumset::EnumSet {
401                     __priv_repr: (value & #all_variants) as #repr,
402                 })
403             }
404         }
405     };
406
407     #[cfg(not(feature = "serde"))]
408     let serde_ops = quote! {};
409
410     let is_uninhabited = info.variants.is_empty();
411     let is_zst = info.variants.len() == 1;
412     let into_impl = if is_uninhabited {
413         quote! {
414             fn enum_into_u32(self) -> u32 {
415                 panic!(concat!(stringify!(#name), " is uninhabited."))
416             }
417             unsafe fn enum_from_u32(val: u32) -> Self {
418                 panic!(concat!(stringify!(#name), " is uninhabited."))
419             }
420         }
421     } else if is_zst {
422         let variant = &info.variants[0].name;
423         quote! {
424             fn enum_into_u32(self) -> u32 {
425                 self as u32
426             }
427             unsafe fn enum_from_u32(val: u32) -> Self {
428                 #name::#variant
429             }
430         }
431     } else {
432         let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
433         let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
434
435         let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
436             .iter()
437             .map(|x| Ident::new(x, Span::call_site()))
438             .collect();
439         let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
440             .iter()
441             .map(|x| Ident::new(x, Span::call_site()))
442             .collect();
443
444         quote! {
445             fn enum_into_u32(self) -> u32 {
446                 self as u32
447             }
448             unsafe fn enum_from_u32(val: u32) -> Self {
449                 // We put these in const fields so the branches they guard aren't generated even
450                 // on -O0
451                 #(const #const_field: bool =
452                     #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
453                 match val {
454                     // Every valid variant value has an explicit branch. If they get optimized out,
455                     // great. If the representation has changed somehow, and they don't, oh well,
456                     // there's still no UB.
457                     #(#variant_value => #name::#variant_name,)*
458                     // Helps hint to the LLVM that this is a transmute. Note that this branch is
459                     // still unreachable.
460                     #(x if #const_field => {
461                         let x = x as #int_type;
462                         *(&x as *const _ as *const #name)
463                     })*
464                     // Default case. Sometimes causes LLVM to generate a table instead of a simple
465                     // transmute, but, oh well.
466                     _ => #core::hint::unreachable_unchecked(),
467                 }
468             }
469         }
470     };
471
472     let eq_impl = if is_uninhabited {
473         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
474     } else {
475         quote!((*self as u32) == (*other as u32))
476     };
477
478     // used in the enum_set! macro `const fn`s.
479     let self_as_repr_mask = if is_uninhabited {
480         quote! { 0 } // impossible anyway
481     } else {
482         quote! { 1 << self as #repr }
483     };
484
485     let super_impls = if info.no_super_impls {
486         quote! {}
487     } else {
488         quote! {
489             impl #core::cmp::PartialEq for #name {
490                 fn eq(&self, other: &Self) -> bool {
491                     #eq_impl
492                 }
493             }
494             impl #core::cmp::Eq for #name { }
495             #[allow(clippy::expl_impl_clone_on_copy)]
496             impl #core::clone::Clone for #name {
497                 fn clone(&self) -> Self {
498                     *self
499                 }
500             }
501             impl #core::marker::Copy for #name { }
502         }
503     };
504
505     let impl_with_repr = if info.explicit_mem_repr.is_some() {
506         quote! {
507             unsafe impl #enumset::EnumSetTypeWithRepr for #name {
508                 type Repr = #repr;
509             }
510         }
511     } else {
512         quote! {}
513     };
514
515     quote! {
516         unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
517             type Repr = #repr;
518             const ALL_BITS: Self::Repr = #all_variants;
519             #into_impl
520             #serde_ops
521         }
522
523         unsafe impl #enumset::EnumSetType for #name { }
524
525         #impl_with_repr
526         #super_impls
527
528         impl #name {
529             /// Creates a new enumset with only this variant.
530             #[deprecated(note = "This method is an internal implementation detail generated by \
531                                  the `enumset` crate's procedural macro. It should not be used \
532                                  directly. Use `EnumSet::only` instead.")]
533             #[doc(hidden)]
534             pub const fn __impl_enumset_internal__const_only(self) -> #enumset::EnumSet<#name> {
535                 #enumset::EnumSet { __priv_repr: #self_as_repr_mask }
536             }
537
538             /// Creates a new enumset with this variant added.
539             #[deprecated(note = "This method is an internal implementation detail generated by \
540                                  the `enumset` crate's procedural macro. It should not be used \
541                                  directly. Use the `|` operator instead.")]
542             #[doc(hidden)]
543             pub const fn __impl_enumset_internal__const_merge(
544                 self, chain: #enumset::EnumSet<#name>,
545             ) -> #enumset::EnumSet<#name> {
546                 #enumset::EnumSet { __priv_repr: chain.__priv_repr | #self_as_repr_mask }
547             }
548         }
549
550         #ops
551     }
552 }
553
554 #[proc_macro_derive(EnumSetType, attributes(enumset))]
555 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
556     let input: DeriveInput = parse_macro_input!(input);
557     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
558         Ok(attrs) => attrs,
559         Err(e) => return e.write_errors().into(),
560     };
561     match derive_enum_set_type_0(input, attrs) {
562         Ok(v) => v,
563         Err(e) => e.to_compile_error().into(),
564     }
565 }
566 fn derive_enum_set_type_0(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
567     if !input.generics.params.is_empty() {
568         error(
569             input.generics.span(),
570             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
571         )
572     } else if let Data::Enum(data) = &input.data {
573         let mut info = EnumSetInfo::new(&input, attrs);
574         for attr in &input.attrs {
575             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
576                 let meta: Ident = attr.parse_args()?;
577                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
578             }
579         }
580         for variant in &data.variants {
581             info.push_variant(variant)?;
582         }
583         info.validate()?;
584         Ok(enum_set_type_impl(info).into())
585     } else {
586         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
587     }
588 }