]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
Add support for renaming the enumset crate.
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit="256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{TokenStream as SynTokenStream, Literal};
8 use std::collections::HashSet;
9 use syn::{*, Result, Error};
10 use syn::export::Span;
11 use syn::spanned::Spanned;
12 use quote::*;
13
14 /// Helper function for emitting compile errors.
15 fn error<T>(span: Span, message: &str) -> Result<T> {
16     Err(Error::new(span, message))
17 }
18
19 /// Decodes the custom attributes for our custom derive.
20 #[derive(FromDeriveInput, Default)]
21 #[darling(attributes(enumset), default)]
22 struct EnumsetAttrs {
23     no_ops: bool,
24     serialize_as_list: bool,
25     serialize_deny_unknown: bool,
26     #[darling(default)]
27     serialize_repr: Option<String>,
28     #[darling(default)]
29     crate_name: Option<String>,
30 }
31
32 /// An variant in the enum set type.
33 struct EnumSetValue {
34     name: Ident,
35     variant_repr: u32,
36 }
37
38 /// Stores information about the enum set type.
39 #[allow(dead_code)]
40 struct EnumSetInfo {
41     name: Ident,
42     crate_name: Option<Ident>,
43     explicit_serde_repr: Option<Ident>,
44     has_signed_repr: bool,
45     has_large_repr: bool,
46     variants: Vec<EnumSetValue>,
47
48     max_discrim: u32,
49     cur_discrim: u32,
50     used_variant_names: HashSet<String>,
51     used_discriminators: HashSet<u32>,
52
53     no_ops: bool,
54     serialize_as_list: bool,
55     serialize_deny_unknown: bool,
56 }
57 impl EnumSetInfo {
58     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
59         EnumSetInfo {
60             name: input.ident.clone(),
61             crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
62             explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
63             has_signed_repr: false,
64             has_large_repr: false,
65             variants: Vec::new(),
66             max_discrim: 0,
67             cur_discrim: 0,
68             used_variant_names: HashSet::new(),
69             used_discriminators: HashSet::new(),
70             no_ops: attrs.no_ops,
71             serialize_as_list: attrs.serialize_as_list,
72             serialize_deny_unknown: attrs.serialize_deny_unknown
73         }
74     }
75
76     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
77         match repr {
78             "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
79             "usize" | "u64" | "u128" => {
80                 self.has_large_repr = true;
81                 Ok(())
82             }
83             "i8" | "i16" | "i32" => {
84                 self.has_signed_repr = true;
85                 Ok(())
86             }
87             "isize" | "i64" | "i128" => {
88                 self.has_signed_repr = true;
89                 self.has_large_repr = true;
90                 Ok(())
91             }
92             _ => error(attr_span, "Unsupported repr.")
93         }
94     }
95     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
96         if self.used_variant_names.contains(&variant.ident.to_string()) {
97             error(variant.span(), "Duplicated variant name.")
98         } else if let Fields::Unit = variant.fields {
99             if let Some((_, expr)) = &variant.discriminant {
100                 let discriminant_fail_message = format!(
101                     "Enum set discriminants must be `u32`s.{}",
102                     if self.has_signed_repr || self.has_large_repr {
103                         format!(
104                             " ({} discrimiants are still unsupported with reprs that allow them.)",
105                             if self.has_large_repr {
106                                 "larger"
107                             } else if self.has_signed_repr {
108                                 "negative"
109                             } else {
110                                 "larger or negative"
111                             }
112                         )
113                     } else {
114                         String::new()
115                     },
116                 );
117                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
118                     match i.base10_parse() {
119                         Ok(val) => self.cur_discrim = val,
120                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
121                     }
122                 } else {
123                     error(variant.span(), &discriminant_fail_message)?;
124                 }
125             }
126
127             let discriminant = self.cur_discrim;
128             if discriminant >= 128 {
129                 let message = if self.variants.len() <= 127 {
130                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
131                 } else {
132                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
133                 };
134                 error(variant.span(), message)?;
135             }
136
137             if self.used_discriminators.contains(&discriminant) {
138                 error(variant.span(), "Duplicated enum discriminant.")?;
139             }
140
141             self.cur_discrim += 1;
142             if discriminant > self.max_discrim {
143                 self.max_discrim = discriminant;
144             }
145             self.variants.push(EnumSetValue {
146                 name: variant.ident.clone(),
147                 variant_repr: discriminant,
148             });
149             self.used_variant_names.insert(variant.ident.to_string());
150             self.used_discriminators.insert(discriminant);
151
152             Ok(())
153         } else {
154             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
155         }
156     }
157     fn validate(&self) -> Result<()> {
158         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
159             let is_overflowed = match explicit_serde_repr.to_string().as_str() {
160                 "u8" => self.max_discrim >= 8,
161                 "u16" => self.max_discrim >= 16,
162                 "u32" => self.max_discrim >= 32,
163                 "u64" => self.max_discrim >= 64,
164                 "u128" => self.max_discrim >= 128,
165                 _ => error(
166                     Span::call_site(),
167                     "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
168                 )?,
169             };
170             if is_overflowed {
171                 error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
172             }
173         }
174         Ok(())
175     }
176
177     fn enumset_repr(&self) -> SynTokenStream {
178         if self.max_discrim <= 7 {
179             quote! { u8 }
180         } else if self.max_discrim <= 15 {
181             quote! { u16 }
182         } else if self.max_discrim <= 31 {
183             quote! { u32 }
184         } else if self.max_discrim <= 63 {
185             quote! { u64 }
186         } else if self.max_discrim <= 127 {
187             quote! { u128 }
188         } else {
189             panic!("max_variant > 127?")
190         }
191     }
192     #[cfg(feature = "serde")]
193     fn serde_repr(&self) -> SynTokenStream {
194         if let Some(serde_repr) = &self.explicit_serde_repr {
195             quote! { #serde_repr }
196         } else {
197             self.enumset_repr()
198         }
199     }
200
201     fn all_variants(&self) -> u128 {
202         let mut accum = 0u128;
203         for variant in &self.variants {
204             assert!(variant.variant_repr <= 127);
205             accum |= 1u128 << variant.variant_repr as u128;
206         }
207         accum
208     }
209 }
210
211 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
212     let name = &info.name;
213     let enumset = match &info.crate_name {
214         Some(crate_name) => quote!(::#crate_name),
215         None => quote!(::enumset),
216     };
217     let typed_enumset = quote!(#enumset::EnumSet<#name>);
218     let core = quote!(#enumset::__internal::core_export);
219
220     let repr = info.enumset_repr();
221     let all_variants = Literal::u128_unsuffixed(info.all_variants());
222
223     let ops = if info.no_ops {
224         quote! {}
225     } else {
226         quote! {
227             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
228                 type Output = #typed_enumset;
229                 fn sub(self, other: O) -> Self::Output {
230                     #enumset::EnumSet::only(self) - other.into()
231                 }
232             }
233             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
234                 type Output = #typed_enumset;
235                 fn bitand(self, other: O) -> Self::Output {
236                     #enumset::EnumSet::only(self) & other.into()
237                 }
238             }
239             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
240                 type Output = #typed_enumset;
241                 fn bitor(self, other: O) -> Self::Output {
242                     #enumset::EnumSet::only(self) | other.into()
243                 }
244             }
245             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
246                 type Output = #typed_enumset;
247                 fn bitxor(self, other: O) -> Self::Output {
248                     #enumset::EnumSet::only(self) ^ other.into()
249                 }
250             }
251             impl #core::ops::Not for #name {
252                 type Output = #typed_enumset;
253                 fn not(self) -> Self::Output {
254                     !#enumset::EnumSet::only(self)
255                 }
256             }
257             impl #core::cmp::PartialEq<#typed_enumset> for #name {
258                 fn eq(&self, other: &#typed_enumset) -> bool {
259                     #enumset::EnumSet::only(*self) == *other
260                 }
261             }
262         }
263     };
264
265
266     #[cfg(feature = "serde")]
267     let serde = quote!(#enumset::__internal::serde);
268
269     #[cfg(feature = "serde")]
270     let serde_ops = if info.serialize_as_list {
271         let expecting_str = format!("a list of {}", name);
272         quote! {
273             fn serialize<S: #serde::Serializer>(
274                 set: #enumset::EnumSet<#name>, ser: S,
275             ) -> #core::result::Result<S::Ok, S::Error> {
276                 use #serde::ser::SerializeSeq;
277                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
278                 for bit in set {
279                     seq.serialize_element(&bit)?;
280                 }
281                 seq.end()
282             }
283             fn deserialize<'de, D: #serde::Deserializer<'de>>(
284                 de: D,
285             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
286                 struct Visitor;
287                 impl <'de> #serde::de::Visitor<'de> for Visitor {
288                     type Value = #enumset::EnumSet<#name>;
289                     fn expecting(
290                         &self, formatter: &mut #core::fmt::Formatter,
291                     ) -> #core::fmt::Result {
292                         write!(formatter, #expecting_str)
293                     }
294                     fn visit_seq<A>(
295                         mut self, mut seq: A,
296                     ) -> #core::result::Result<Self::Value, A::Error> where
297                         A: #serde::de::SeqAccess<'de>
298                     {
299                         let mut accum = #enumset::EnumSet::<#name>::new();
300                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
301                             accum |= val;
302                         }
303                         #core::prelude::v1::Ok(accum)
304                     }
305                 }
306                 de.deserialize_seq(Visitor)
307             }
308         }
309     } else {
310         let serialize_repr = info.serde_repr();
311         let check_unknown = if info.serialize_deny_unknown {
312             quote! {
313                 if value & !#all_variants != 0 {
314                     use #serde::de::Error;
315                     return #core::prelude::v1::Err(
316                         D::Error::custom("enumset contains unknown bits")
317                     )
318                 }
319             }
320         } else {
321             quote! { }
322         };
323         quote! {
324             fn serialize<S: #serde::Serializer>(
325                 set: #enumset::EnumSet<#name>, ser: S,
326             ) -> #core::result::Result<S::Ok, S::Error> {
327                 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
328             }
329             fn deserialize<'de, D: #serde::Deserializer<'de>>(
330                 de: D,
331             ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
332                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
333                 #check_unknown
334                 #core::prelude::v1::Ok(#enumset::EnumSet {
335                     __enumset_underlying: (value & #all_variants) as #repr,
336                 })
337             }
338         }
339     };
340
341     #[cfg(not(feature = "serde"))]
342     let serde_ops = quote! { };
343
344     let is_uninhabited = info.variants.is_empty();
345     let is_zst = info.variants.len() == 1;
346     let into_impl = if is_uninhabited {
347         quote! {
348             fn enum_into_u32(self) -> u32 {
349                 panic!(concat!(stringify!(#name), " is uninhabited."))
350             }
351             unsafe fn enum_from_u32(val: u32) -> Self {
352                 panic!(concat!(stringify!(#name), " is uninhabited."))
353             }
354         }
355     } else if is_zst {
356         let variant = &info.variants[0].name;
357         quote! {
358             fn enum_into_u32(self) -> u32 {
359                 self as u32
360             }
361             unsafe fn enum_from_u32(val: u32) -> Self {
362                 #name::#variant
363             }
364         }
365     } else {
366         let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
367         let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
368
369         let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
370             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
371         let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
372             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
373
374         quote! {
375             fn enum_into_u32(self) -> u32 {
376                 self as u32
377             }
378             unsafe fn enum_from_u32(val: u32) -> Self {
379                 // We put these in const fields so they aren't generated even on -O0
380                 #(const #const_field: bool =
381                     #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
382                 match val {
383                     // Every valid variant value has an explicit branch. If they get optimized out,
384                     // great. If the representation has changed somehow, and they don't, oh well,
385                     // there's still no UB.
386                     #(#variant_value => #name::#variant_name,)*
387                     // Helps hint to the LLVM that this is a transmute. Note that this branch is
388                     // still unreachable.
389                     #(x if #const_field => {
390                         let x = x as #int_type;
391                         *(&x as *const _ as *const #name)
392                     })*
393                     // Default case. Sometimes causes LLVM to generate a table instead of a simple
394                     // transmute, but, oh well.
395                     _ => #core::hint::unreachable_unchecked(),
396                 }
397             }
398         }
399     };
400
401     let eq_impl = if is_uninhabited {
402         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
403     } else {
404         quote!((*self as u32) == (*other as u32))
405     };
406
407     quote! {
408         unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
409             type Repr = #repr;
410             const ALL_BITS: Self::Repr = #all_variants;
411             #into_impl
412             #serde_ops
413         }
414
415         unsafe impl #enumset::EnumSetType for #name { }
416
417         impl #core::cmp::PartialEq for #name {
418             fn eq(&self, other: &Self) -> bool {
419                 #eq_impl
420             }
421         }
422         impl #core::cmp::Eq for #name { }
423         impl #core::clone::Clone for #name {
424             fn clone(&self) -> Self {
425                 *self
426             }
427         }
428         impl #core::marker::Copy for #name { }
429
430         #ops
431     }
432 }
433
434 fn derive_enum_set_type_impl(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
435     if !input.generics.params.is_empty() {
436         error(
437             input.generics.span(),
438             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
439         )
440     } else if let Data::Enum(data) = &input.data {
441         let mut info = EnumSetInfo::new(&input, attrs);
442         for attr in &input.attrs {
443             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
444                 let meta: Ident = attr.parse_args()?;
445                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
446             }
447         }
448         for variant in &data.variants {
449             info.push_variant(variant)?;
450         }
451         info.validate()?;
452         Ok(enum_set_type_impl(info).into())
453     } else {
454         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
455     }
456 }
457
458 #[proc_macro_derive(EnumSetType, attributes(enumset))]
459 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
460     let input: DeriveInput = parse_macro_input!(input);
461     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
462         Ok(attrs) => attrs,
463         Err(e) => return e.write_errors().into(),
464     };
465     match derive_enum_set_type_impl(input, attrs) {
466         Ok(v) => v,
467         Err(e) => e.to_compile_error().into(),
468     }
469 }