]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
88ac367e0ad914cf317bc046bebc9a51e1c655e2
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit = "256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{Literal, Span, TokenStream as SynTokenStream};
8 use quote::*;
9 use std::{collections::HashSet, fmt::Display};
10 use syn::spanned::Spanned;
11 use syn::{Error, Result, *};
12
13 /// Helper function for emitting compile errors.
14 fn error<T>(span: Span, message: impl Display) -> Result<T> {
15     Err(Error::new(span, message))
16 }
17
18 /// Decodes the custom attributes for our custom derive.
19 #[derive(FromDeriveInput, Default)]
20 #[darling(attributes(enumset), default)]
21 struct EnumsetAttrs {
22     no_ops: bool,
23     no_super_impls: bool,
24     #[darling(default)]
25     repr: Option<String>,
26     serialize_as_list: bool,
27     serialize_deny_unknown: bool,
28     #[darling(default)]
29     serialize_repr: Option<String>,
30     #[darling(default)]
31     crate_name: Option<String>,
32 }
33
34 /// An variant in the enum set type.
35 struct EnumSetValue {
36     /// The name of the variant.
37     name: Ident,
38     /// The discriminant of the variant.
39     variant_repr: u32,
40 }
41
42 /// Stores information about the enum set type.
43 #[allow(dead_code)]
44 struct EnumSetInfo {
45     /// The name of the enum.
46     name: Ident,
47     /// The crate name to use.
48     crate_name: Option<Ident>,
49     /// The numeric type to represent the `EnumSet` as in memory.
50     explicit_mem_repr: Option<Ident>,
51     /// The numeric type to serialize the enum as.
52     explicit_serde_repr: Option<Ident>,
53     /// Whether the underlying repr of the enum supports negative values.
54     has_signed_repr: bool,
55     /// Whether the underlying repr of the enum supports values higher than 2^32.
56     has_large_repr: bool,
57     /// A list of variants in the enum.
58     variants: Vec<EnumSetValue>,
59
60     /// The highest encountered variant discriminant.
61     max_discrim: u32,
62     /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
63     cur_discrim: u32,
64     /// A list of variant names that are already in use.
65     used_variant_names: HashSet<String>,
66     /// A list of variant discriminants that are already in use.
67     used_discriminants: HashSet<u32>,
68
69     /// Avoid generating operator overloads on the enum type.
70     no_ops: bool,
71     /// Avoid generating implementations for `Clone`, `Copy`, `Eq`, and `PartialEq`.
72     no_super_impls: bool,
73     /// Serialize the enum as a list.
74     serialize_as_list: bool,
75     /// Disallow unknown bits while deserializing the enum.
76     serialize_deny_unknown: bool,
77 }
78 impl EnumSetInfo {
79     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
80         EnumSetInfo {
81             name: input.ident.clone(),
82             crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
83             explicit_mem_repr: attrs.repr.map(|x| Ident::new(&x, Span::call_site())),
84             explicit_serde_repr: attrs
85                 .serialize_repr
86                 .map(|x| Ident::new(&x, Span::call_site())),
87             has_signed_repr: false,
88             has_large_repr: false,
89             variants: Vec::new(),
90             max_discrim: 0,
91             cur_discrim: 0,
92             used_variant_names: HashSet::new(),
93             used_discriminants: HashSet::new(),
94             no_ops: attrs.no_ops,
95             no_super_impls: attrs.no_super_impls,
96             serialize_as_list: attrs.serialize_as_list,
97             serialize_deny_unknown: attrs.serialize_deny_unknown,
98         }
99     }
100
101     /// Sets an explicit repr for the enumset.
102     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
103         // Check whether the repr is supported, and if so, set some flags for better error
104         // messages later on.
105         match repr {
106             "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
107             "usize" | "u64" | "u128" => {
108                 self.has_large_repr = true;
109                 Ok(())
110             }
111             "i8" | "i16" | "i32" => {
112                 self.has_signed_repr = true;
113                 Ok(())
114             }
115             "isize" | "i64" | "i128" => {
116                 self.has_signed_repr = true;
117                 self.has_large_repr = true;
118                 Ok(())
119             }
120             _ => error(attr_span, "Unsupported repr."),
121         }
122     }
123     /// Adds a variant to the enumset.
124     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
125         if self.used_variant_names.contains(&variant.ident.to_string()) {
126             error(variant.span(), "Duplicated variant name.")
127         } else if let Fields::Unit = variant.fields {
128             // Parse the discriminant.
129             if let Some((_, expr)) = &variant.discriminant {
130                 let discriminant_fail_message = format!(
131                     "Enum set discriminants must be `u32`s.{}",
132                     if self.has_signed_repr || self.has_large_repr {
133                         format!(
134                             " ({} discrimiants are still unsupported with reprs that allow them.)",
135                             if self.has_large_repr {
136                                 "larger"
137                             } else if self.has_signed_repr {
138                                 "negative"
139                             } else {
140                                 "larger or negative"
141                             }
142                         )
143                     } else {
144                         String::new()
145                     },
146                 );
147                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
148                     match i.base10_parse() {
149                         Ok(val) => self.cur_discrim = val,
150                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
151                     }
152                 } else {
153                     error(variant.span(), &discriminant_fail_message)?;
154                 }
155             }
156
157             // Validate the discriminant.
158             let discriminant = self.cur_discrim;
159             if discriminant >= 128 {
160                 let message = if self.variants.len() <= 127 {
161                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
162                 } else {
163                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
164                 };
165                 error(variant.span(), message)?;
166             }
167             if self.used_discriminants.contains(&discriminant) {
168                 error(variant.span(), "Duplicated enum discriminant.")?;
169             }
170
171             // Add the variant to the info.
172             self.cur_discrim += 1;
173             if discriminant > self.max_discrim {
174                 self.max_discrim = discriminant;
175             }
176             self.variants
177                 .push(EnumSetValue { name: variant.ident.clone(), variant_repr: discriminant });
178             self.used_variant_names.insert(variant.ident.to_string());
179             self.used_discriminants.insert(discriminant);
180
181             Ok(())
182         } else {
183             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
184         }
185     }
186     /// Validate the enumset type.
187     fn validate(&self) -> Result<()> {
188         fn do_check(ty: &str, max_discrim: u32, what: &str) -> Result<()> {
189             let is_overflowed = match ty {
190                 "u8" => max_discrim >= 8,
191                 "u16" => max_discrim >= 16,
192                 "u32" => max_discrim >= 32,
193                 "u64" => max_discrim >= 64,
194                 "u128" => max_discrim >= 128,
195                 _ => error(
196                     Span::call_site(),
197                     format!(
198                         "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for {}.",
199                         what
200                     ),
201                 )?,
202             };
203             if is_overflowed {
204                 error(Span::call_site(), format!("{} cannot be smaller than bitset.", what))?;
205             }
206             Ok(())
207         }
208
209         // Check if all bits of the bitset can fit in the serialization representation.
210         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
211             do_check(&explicit_serde_repr.to_string(), self.max_discrim, "serialize_repr")?;
212         }
213
214         // Check if all bits of the bitset can fit in the memory representation, if one was given.
215         if let Some(explicit_mem_repr) = &self.explicit_mem_repr {
216             do_check(&explicit_mem_repr.to_string(), self.max_discrim, "repr")?;
217         }
218         Ok(())
219     }
220
221     /// Computes the underlying type used to store the enumset.
222     fn enumset_repr(&self) -> SynTokenStream {
223         if let Some(explicit_mem_repr) = &self.explicit_mem_repr {
224             explicit_mem_repr.to_token_stream()
225         } else if self.max_discrim <= 7 {
226             quote! { u8 }
227         } else if self.max_discrim <= 15 {
228             quote! { u16 }
229         } else if self.max_discrim <= 31 {
230             quote! { u32 }
231         } else if self.max_discrim <= 63 {
232             quote! { u64 }
233         } else if self.max_discrim <= 127 {
234             quote! { u128 }
235         } else {
236             panic!("max_variant > 127?")
237         }
238     }
239     /// Computes the underlying type used to serialize the enumset.
240     #[cfg(feature = "serde")]
241     fn serde_repr(&self) -> SynTokenStream {
242         if let Some(serde_repr) = &self.explicit_serde_repr {
243             quote! { #serde_repr }
244         } else {
245             self.enumset_repr()
246         }
247     }
248
249     /// Returns a bitmask of all variants in the set.
250     fn all_variants(&self) -> u128 {
251         let mut accum = 0u128;
252         for variant in &self.variants {
253             assert!(variant.variant_repr <= 127);
254             accum |= 1u128 << variant.variant_repr as u128;
255         }
256         accum
257     }
258 }
259
260 /// Generates the actual `EnumSetType` impl.
261 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
262     let name = &info.name;
263
264     let enumset = match &info.crate_name {
265         Some(crate_name) => quote!(::#crate_name),
266         None => {
267             #[cfg(feature = "proc-macro-crate")]
268             {
269                 use proc_macro_crate::FoundCrate;
270
271                 let crate_name = proc_macro_crate::crate_name("enumset");
272                 match crate_name {
273                     Ok(FoundCrate::Name(name)) => {
274                         let ident = Ident::new(&name, Span::call_site());
275                         quote!(::#ident)
276                     }
277                     _ => quote!(::enumset),
278                 }
279             }
280
281             #[cfg(not(feature = "proc-macro-crate"))]
282             {
283                 quote!(::enumset)
284             }
285         }
286     };
287     let typed_enumset = quote!(#enumset::EnumSet<#name>);
288     let core = quote!(#enumset::__internal::core_export);
289
290     let repr = info.enumset_repr();
291     let all_variants = Literal::u128_unsuffixed(info.all_variants());
292
293     let ops = if info.no_ops {
294         quote! {}
295     } else {
296         quote! {
297             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
298                 type Output = #typed_enumset;
299                 fn sub(self, other: O) -> Self::Output {
300                     #enumset::EnumSet::only(self) - other.into()
301                 }
302             }
303             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
304                 type Output = #typed_enumset;
305                 fn bitand(self, other: O) -> Self::Output {
306                     #enumset::EnumSet::only(self) & other.into()
307                 }
308             }
309             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
310                 type Output = #typed_enumset;
311                 fn bitor(self, other: O) -> Self::Output {
312                     #enumset::EnumSet::only(self) | other.into()
313                 }
314             }
315             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
316                 type Output = #typed_enumset;
317                 fn bitxor(self, other: O) -> Self::Output {
318                     #enumset::EnumSet::only(self) ^ other.into()
319                 }
320             }
321             impl #core::ops::Not for #name {
322                 type Output = #typed_enumset;
323                 fn not(self) -> Self::Output {
324                     !#enumset::EnumSet::only(self)
325                 }
326             }
327             impl #core::cmp::PartialEq<#typed_enumset> for #name {
328                 fn eq(&self, other: &#typed_enumset) -> bool {
329                     #enumset::EnumSet::only(*self) == *other
330                 }
331             }
332         }
333     };
334
335     #[cfg(feature = "serde")]
336     let serde = quote!(#enumset::__internal::serde);
337
338     #[cfg(feature = "serde")]
339     let serde_ops = if info.serialize_as_list {
340         let expecting_str = format!("a list of {}", name);
341         quote! {
342             fn serialize<S: #serde::Serializer>(
343                 set: #enumset::EnumSet<#name>, ser: S,
344             ) -> #core::result::Result<S::Ok, S::Error> {
345                 use #serde::ser::SerializeSeq;
346                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
347                 for bit in set {
348                     seq.serialize_element(&bit)?;
349                 }
350                 seq.end()
351             }
352             fn deserialize<'de, D: #serde::Deserializer<'de>>(
353                 de: D,
354             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
355                 struct Visitor;
356                 impl <'de> #serde::de::Visitor<'de> for Visitor {
357                     type Value = #enumset::EnumSet<#name>;
358                     fn expecting(
359                         &self, formatter: &mut #core::fmt::Formatter,
360                     ) -> #core::fmt::Result {
361                         write!(formatter, #expecting_str)
362                     }
363                     fn visit_seq<A>(
364                         mut self, mut seq: A,
365                     ) -> #core::result::Result<Self::Value, A::Error> where
366                         A: #serde::de::SeqAccess<'de>
367                     {
368                         let mut accum = #enumset::EnumSet::<#name>::new();
369                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
370                             accum |= val;
371                         }
372                         #core::prelude::v1::Ok(accum)
373                     }
374                 }
375                 de.deserialize_seq(Visitor)
376             }
377         }
378     } else {
379         let serialize_repr = info.serde_repr();
380         let check_unknown = if info.serialize_deny_unknown {
381             quote! {
382                 if value & !#all_variants != 0 {
383                     use #serde::de::Error;
384                     return #core::prelude::v1::Err(
385                         D::Error::custom("enumset contains unknown bits")
386                     )
387                 }
388             }
389         } else {
390             quote! {}
391         };
392         quote! {
393             fn serialize<S: #serde::Serializer>(
394                 set: #enumset::EnumSet<#name>, ser: S,
395             ) -> #core::result::Result<S::Ok, S::Error> {
396                 #serde::Serialize::serialize(&(set.__priv_repr as #serialize_repr), ser)
397             }
398             fn deserialize<'de, D: #serde::Deserializer<'de>>(
399                 de: D,
400             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
401                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
402                 #check_unknown
403                 #core::prelude::v1::Ok(#enumset::EnumSet {
404                     __priv_repr: (value & #all_variants) as #repr,
405                 })
406             }
407         }
408     };
409
410     #[cfg(not(feature = "serde"))]
411     let serde_ops = quote! {};
412
413     let is_uninhabited = info.variants.is_empty();
414     let is_zst = info.variants.len() == 1;
415     let into_impl = if is_uninhabited {
416         quote! {
417             fn enum_into_u32(self) -> u32 {
418                 panic!(concat!(stringify!(#name), " is uninhabited."))
419             }
420             unsafe fn enum_from_u32(val: u32) -> Self {
421                 panic!(concat!(stringify!(#name), " is uninhabited."))
422             }
423         }
424     } else if is_zst {
425         let variant = &info.variants[0].name;
426         quote! {
427             fn enum_into_u32(self) -> u32 {
428                 self as u32
429             }
430             unsafe fn enum_from_u32(val: u32) -> Self {
431                 #name::#variant
432             }
433         }
434     } else {
435         let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
436         let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
437
438         let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
439             .iter()
440             .map(|x| Ident::new(x, Span::call_site()))
441             .collect();
442         let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
443             .iter()
444             .map(|x| Ident::new(x, Span::call_site()))
445             .collect();
446
447         quote! {
448             fn enum_into_u32(self) -> u32 {
449                 self as u32
450             }
451             unsafe fn enum_from_u32(val: u32) -> Self {
452                 // We put these in const fields so the branches they guard aren't generated even
453                 // on -O0
454                 #(const #const_field: bool =
455                     #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
456                 match val {
457                     // Every valid variant value has an explicit branch. If they get optimized out,
458                     // great. If the representation has changed somehow, and they don't, oh well,
459                     // there's still no UB.
460                     #(#variant_value => #name::#variant_name,)*
461                     // Helps hint to the LLVM that this is a transmute. Note that this branch is
462                     // still unreachable.
463                     #(x if #const_field => {
464                         let x = x as #int_type;
465                         *(&x as *const _ as *const #name)
466                     })*
467                     // Default case. Sometimes causes LLVM to generate a table instead of a simple
468                     // transmute, but, oh well.
469                     _ => #core::hint::unreachable_unchecked(),
470                 }
471             }
472         }
473     };
474
475     let eq_impl = if is_uninhabited {
476         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
477     } else {
478         quote!((*self as u32) == (*other as u32))
479     };
480
481     // used in the enum_set! macro `const fn`s.
482     let self_as_repr_mask = if is_uninhabited {
483         quote! { 0 } // impossible anyway
484     } else {
485         quote! { 1 << self as #repr }
486     };
487
488     let super_impls = if info.no_super_impls {
489         quote! {}
490     } else {
491         quote! {
492             impl #core::cmp::PartialEq for #name {
493                 fn eq(&self, other: &Self) -> bool {
494                     #eq_impl
495                 }
496             }
497             impl #core::cmp::Eq for #name { }
498             #[allow(clippy::expl_impl_clone_on_copy)]
499             impl #core::clone::Clone for #name {
500                 fn clone(&self) -> Self {
501                     *self
502                 }
503             }
504             impl #core::marker::Copy for #name { }
505         }
506     };
507
508     let impl_with_repr = if info.explicit_mem_repr.is_some() {
509         quote! {
510             unsafe impl #enumset::EnumSetTypeWithRepr for #name {
511                 type Repr = #repr;
512             }
513         }
514     } else {
515         quote! {}
516     };
517
518     quote! {
519         unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
520             type Repr = #repr;
521             const ALL_BITS: Self::Repr = #all_variants;
522             #into_impl
523             #serde_ops
524         }
525
526         unsafe impl #enumset::EnumSetType for #name { }
527
528         #impl_with_repr
529         #super_impls
530
531         impl #name {
532             /// Creates a new enumset with only this variant.
533             #[deprecated(note = "This method is an internal implementation detail generated by \
534                                  the `enumset` crate's procedural macro. It should not be used \
535                                  directly. Use `EnumSet::only` instead.")]
536             #[doc(hidden)]
537             pub const fn __impl_enumset_internal__const_only(self) -> #enumset::EnumSet<#name> {
538                 #enumset::EnumSet { __priv_repr: #self_as_repr_mask }
539             }
540
541             /// Creates a new enumset with this variant added.
542             #[deprecated(note = "This method is an internal implementation detail generated by \
543                                  the `enumset` crate's procedural macro. It should not be used \
544                                  directly. Use the `|` operator instead.")]
545             #[doc(hidden)]
546             pub const fn __impl_enumset_internal__const_merge(
547                 self, chain: #enumset::EnumSet<#name>,
548             ) -> #enumset::EnumSet<#name> {
549                 #enumset::EnumSet { __priv_repr: chain.__priv_repr | #self_as_repr_mask }
550             }
551         }
552
553         #ops
554     }
555 }
556
557 #[proc_macro_derive(EnumSetType, attributes(enumset))]
558 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
559     let input: DeriveInput = parse_macro_input!(input);
560     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
561         Ok(attrs) => attrs,
562         Err(e) => return e.write_errors().into(),
563     };
564     match derive_enum_set_type_0(input, attrs) {
565         Ok(v) => v,
566         Err(e) => e.to_compile_error().into(),
567     }
568 }
569 fn derive_enum_set_type_0(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
570     if !input.generics.params.is_empty() {
571         error(
572             input.generics.span(),
573             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
574         )
575     } else if let Data::Enum(data) = &input.data {
576         let mut info = EnumSetInfo::new(&input, attrs);
577         for attr in &input.attrs {
578             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
579                 let meta: Ident = attr.parse_args()?;
580                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
581             }
582         }
583         for variant in &data.variants {
584             info.push_variant(variant)?;
585         }
586         info.validate()?;
587         Ok(enum_set_type_impl(info).into())
588     } else {
589         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
590     }
591 }