]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
Refactor enumset_derive in preperation for future changes.
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit="256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{TokenStream as SynTokenStream, Literal};
8 use std::collections::HashSet;
9 use syn::{*, Result, Error};
10 use syn::export::Span;
11 use syn::spanned::Spanned;
12 use quote::*;
13
14 /// Helper function for emitting compile errors.
15 fn error<T>(span: Span, message: &str) -> Result<T> {
16     Err(Error::new(span, message))
17 }
18
19 /// Decodes the custom attributes for our custom derive.
20 #[derive(FromDeriveInput, Default)]
21 #[darling(attributes(enumset), default)]
22 struct EnumsetAttrs {
23     no_ops: bool,
24     serialize_as_list: bool,
25     serialize_deny_unknown: bool,
26     #[darling(default)]
27     serialize_repr: Option<String>,
28 }
29
30 /// An variant in the enum set type.
31 struct EnumSetValue {
32     name: Ident,
33     variant_repr: u32,
34 }
35
36 /// Stores information about the enum set type.
37 #[allow(dead_code)]
38 struct EnumSetInfo {
39     name: Ident,
40     explicit_repr: Option<Ident>,
41     explicit_serde_repr: Option<Ident>,
42     has_signed_repr: bool,
43     has_large_repr: bool,
44     variants: Vec<EnumSetValue>,
45
46     max_discrim: u32,
47     cur_discrim: u32,
48     used_variant_names: HashSet<String>,
49     used_discriminators: HashSet<u32>,
50
51     no_ops: bool,
52     serialize_as_list: bool,
53     serialize_deny_unknown: bool,
54 }
55 impl EnumSetInfo {
56     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
57         EnumSetInfo {
58             name: input.ident.clone(),
59             explicit_repr: None,
60             explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
61             has_signed_repr: false,
62             has_large_repr: false,
63             variants: Vec::new(),
64             max_discrim: 0,
65             cur_discrim: 0,
66             used_variant_names: HashSet::new(),
67             used_discriminators: HashSet::new(),
68             no_ops: attrs.no_ops,
69             serialize_as_list: attrs.serialize_as_list,
70             serialize_deny_unknown: attrs.serialize_deny_unknown
71         }
72     }
73
74     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
75         if self.explicit_repr.is_some() {
76             error(attr_span, "Cannot duplicate #[repr(...)] annotations.")
77         } else {
78             self.explicit_repr = Some(Ident::new(match repr {
79                 "Rust" | "C" => return Ok(()), // We assume default repr in these cases.
80                 "u8" | "u16" | "u32" => repr,
81                 "usize" | "u64" | "u128" => {
82                     self.has_large_repr = true;
83                     repr
84                 }
85                 "i8" | "i16" | "i32" => {
86                     self.has_signed_repr = true;
87                     repr
88                 }
89                 "isize" | "i64" | "i128" => {
90                     self.has_signed_repr = true;
91                     self.has_large_repr = true;
92                     repr
93                 }
94                 _ => return error(attr_span, "Unsupported repr.")
95             }, Span::call_site()));
96             Ok(())
97         }
98     }
99     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
100         if self.used_variant_names.contains(&variant.ident.to_string()) {
101             error(variant.span(), "Duplicated variant name.")
102         } else if let Fields::Unit = variant.fields {
103             if let Some((_, expr)) = &variant.discriminant {
104                 let discriminant_fail_message = format!(
105                     "Enum set discriminants must be `u32`s.{}",
106                     if self.has_signed_repr || self.has_large_repr {
107                         format!(
108                             " ({} discrimiants are still unsupported with reprs that allow them.)",
109                             if self.has_large_repr {
110                                 "larger"
111                             } else if self.has_signed_repr {
112                                 "negative"
113                             } else {
114                                 "larger or negative"
115                             }
116                         )
117                     } else {
118                         String::new()
119                     },
120                 );
121                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
122                     match i.base10_parse() {
123                         Ok(val) => self.cur_discrim = val,
124                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
125                     }
126                 } else {
127                     error(variant.span(), &discriminant_fail_message)?;
128                 }
129             }
130
131             let discriminant = self.cur_discrim;
132             if discriminant >= 128 {
133                 let message = if self.variants.len() <= 127 {
134                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
135                 } else {
136                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
137                 };
138                 error(variant.span(), message)?;
139             }
140
141             if self.used_discriminators.contains(&discriminant) {
142                 error(variant.span(), "Duplicated enum discriminant.")?;
143             }
144
145             self.cur_discrim += 1;
146             if discriminant > self.max_discrim {
147                 self.max_discrim = discriminant;
148             }
149             self.variants.push(EnumSetValue {
150                 name: variant.ident.clone(),
151                 variant_repr: discriminant,
152             });
153             self.used_variant_names.insert(variant.ident.to_string());
154             self.used_discriminators.insert(discriminant);
155
156             Ok(())
157         } else {
158             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
159         }
160     }
161     fn validate(&self) -> Result<()> {
162         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
163             let is_overflowed = match explicit_serde_repr.to_string().as_str() {
164                 "u8" => self.max_discrim >= 8,
165                 "u16" => self.max_discrim >= 16,
166                 "u32" => self.max_discrim >= 32,
167                 "u64" => self.max_discrim >= 64,
168                 "u128" => self.max_discrim >= 128,
169                 _ => error(
170                     Span::call_site(),
171                     "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
172                 )?,
173             };
174             if is_overflowed {
175                 error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
176             }
177         }
178         Ok(())
179     }
180
181     fn enum_repr(&self) -> SynTokenStream {
182         if let Some(explicit_repr) = &self.explicit_repr {
183             quote! { #explicit_repr }
184         } else if self.max_discrim < 0x100 {
185             quote! { u8 }
186         } else if self.max_discrim < 0x10000 {
187             quote! { u16 }
188         } else {
189             quote! { u32 }
190         }
191     }
192     fn enumset_repr(&self) -> SynTokenStream {
193         if self.max_discrim <= 7 {
194             quote! { u8 }
195         } else if self.max_discrim <= 15 {
196             quote! { u16 }
197         } else if self.max_discrim <= 31 {
198             quote! { u32 }
199         } else if self.max_discrim <= 63 {
200             quote! { u64 }
201         } else if self.max_discrim <= 127 {
202             quote! { u128 }
203         } else {
204             panic!("max_variant > 127?")
205         }
206     }
207     #[cfg(feature = "serde")]
208     fn serde_repr(&self) -> SynTokenStream {
209         if let Some(serde_repr) = &self.explicit_serde_repr {
210             quote! { #serde_repr }
211         } else {
212             self.enumset_repr()
213         }
214     }
215
216     fn all_variants(&self) -> u128 {
217         let mut accum = 0u128;
218         for variant in &self.variants {
219             assert!(variant.variant_repr <= 127);
220             accum |= 1u128 << variant.variant_repr as u128;
221         }
222         accum
223     }
224 }
225
226 fn enum_set_type_impl(
227     info: EnumSetInfo,
228 //    name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
229 //    enum_repr: Ident,
230 ) -> SynTokenStream {
231     let name = &info.name;
232     let typed_enumset = quote!(::enumset::EnumSet<#name>);
233     let core = quote!(::enumset::internal::core_export);
234
235     let repr = info.enumset_repr();
236     let enum_repr = info.enum_repr();
237     let all_variants = Literal::u128_unsuffixed(info.all_variants());
238
239     let ops = if info.no_ops {
240         quote! {}
241     } else {
242         quote! {
243             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
244                 type Output = #typed_enumset;
245                 fn sub(self, other: O) -> Self::Output {
246                     ::enumset::EnumSet::only(self) - other.into()
247                 }
248             }
249             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
250                 type Output = #typed_enumset;
251                 fn bitand(self, other: O) -> Self::Output {
252                     ::enumset::EnumSet::only(self) & other.into()
253                 }
254             }
255             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
256                 type Output = #typed_enumset;
257                 fn bitor(self, other: O) -> Self::Output {
258                     ::enumset::EnumSet::only(self) | other.into()
259                 }
260             }
261             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
262                 type Output = #typed_enumset;
263                 fn bitxor(self, other: O) -> Self::Output {
264                     ::enumset::EnumSet::only(self) ^ other.into()
265                 }
266             }
267             impl #core::ops::Not for #name {
268                 type Output = #typed_enumset;
269                 fn not(self) -> Self::Output {
270                     !::enumset::EnumSet::only(self)
271                 }
272             }
273             impl #core::cmp::PartialEq<#typed_enumset> for #name {
274                 fn eq(&self, other: &#typed_enumset) -> bool {
275                     ::enumset::EnumSet::only(*self) == *other
276                 }
277             }
278         }
279     };
280
281
282     #[cfg(feature = "serde")]
283     let serde = quote!(::enumset::internal::serde);
284
285     #[cfg(feature = "serde")]
286     let serde_ops = if info.serialize_as_list {
287         let expecting_str = format!("a list of {}", name);
288         quote! {
289             fn serialize<S: #serde::Serializer>(
290                 set: ::enumset::EnumSet<#name>, ser: S,
291             ) -> #core::result::Result<S::Ok, S::Error> {
292                 use #serde::ser::SerializeSeq;
293                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
294                 for bit in set {
295                     seq.serialize_element(&bit)?;
296                 }
297                 seq.end()
298             }
299             fn deserialize<'de, D: #serde::Deserializer<'de>>(
300                 de: D,
301             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
302                 struct Visitor;
303                 impl <'de> #serde::de::Visitor<'de> for Visitor {
304                     type Value = ::enumset::EnumSet<#name>;
305                     fn expecting(
306                         &self, formatter: &mut #core::fmt::Formatter,
307                     ) -> #core::fmt::Result {
308                         write!(formatter, #expecting_str)
309                     }
310                     fn visit_seq<A>(
311                         mut self, mut seq: A,
312                     ) -> #core::result::Result<Self::Value, A::Error> where
313                         A: #serde::de::SeqAccess<'de>
314                     {
315                         let mut accum = ::enumset::EnumSet::<#name>::new();
316                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
317                             accum |= val;
318                         }
319                         #core::prelude::v1::Ok(accum)
320                     }
321                 }
322                 de.deserialize_seq(Visitor)
323             }
324         }
325     } else {
326         let serialize_repr = info.serde_repr();
327         let check_unknown = if info.serialize_deny_unknown {
328             quote! {
329                 if value & !#all_variants != 0 {
330                     use #serde::de::Error;
331                     return #core::prelude::v1::Err(
332                         D::Error::custom("enumset contains unknown bits")
333                     )
334                 }
335             }
336         } else {
337             quote! { }
338         };
339         quote! {
340             fn serialize<S: #serde::Serializer>(
341                 set: ::enumset::EnumSet<#name>, ser: S,
342             ) -> #core::result::Result<S::Ok, S::Error> {
343                 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
344             }
345             fn deserialize<'de, D: #serde::Deserializer<'de>>(
346                 de: D,
347             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
348                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
349                 #check_unknown
350                 #core::prelude::v1::Ok(::enumset::EnumSet {
351                     __enumset_underlying: (value & #all_variants) as #repr,
352                 })
353             }
354         }
355     };
356
357     #[cfg(not(feature = "serde"))]
358     let serde_ops = quote! { };
359
360     let is_uninhabited = info.variants.is_empty();
361     let is_zst = info.variants.len() == 1;
362     let into_impl = if is_uninhabited {
363         quote! {
364             fn enum_into_u32(self) -> u32 {
365                 panic!(concat!(stringify!(#name), " is uninhabited."))
366             }
367             unsafe fn enum_from_u32(val: u32) -> Self {
368                 panic!(concat!(stringify!(#name), " is uninhabited."))
369             }
370         }
371     } else if is_zst {
372         let variant = &info.variants[0].name;
373         quote! {
374             fn enum_into_u32(self) -> u32 {
375                 self as u32
376             }
377             unsafe fn enum_from_u32(val: u32) -> Self {
378                 #name::#variant
379             }
380         }
381     } else {
382         quote! {
383             fn enum_into_u32(self) -> u32 {
384                 self as u32
385             }
386             unsafe fn enum_from_u32(val: u32) -> Self {
387                 #core::mem::transmute(val as #enum_repr)
388             }
389         }
390     };
391
392     let eq_impl = if is_uninhabited {
393         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
394     } else {
395         quote!((*self as u32) == (*other as u32))
396     };
397
398     quote! {
399         unsafe impl ::enumset::internal::EnumSetTypePrivate for #name {
400             type Repr = #repr;
401             const ALL_BITS: Self::Repr = #all_variants;
402             #into_impl
403             #serde_ops
404         }
405
406         unsafe impl ::enumset::EnumSetType for #name { }
407
408         impl #core::cmp::PartialEq for #name {
409             fn eq(&self, other: &Self) -> bool {
410                 #eq_impl
411             }
412         }
413         impl #core::cmp::Eq for #name { }
414         impl #core::clone::Clone for #name {
415             fn clone(&self) -> Self {
416                 *self
417             }
418         }
419         impl #core::marker::Copy for #name { }
420
421         #ops
422     }
423 }
424
425 fn derive_enum_set_type_impl(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
426     if !input.generics.params.is_empty() {
427         error(
428             input.generics.span(),
429             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
430         )
431     } else if let Data::Enum(data) = &input.data {
432         let mut info = EnumSetInfo::new(&input, attrs);
433         for attr in &input.attrs {
434             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
435                 let meta: Ident = attr.parse_args()?;
436                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
437             }
438         }
439         for variant in &data.variants {
440             info.push_variant(variant)?;
441         }
442         info.validate()?;
443         Ok(enum_set_type_impl(info).into())
444     } else {
445         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
446     }
447 }
448
449 #[proc_macro_derive(EnumSetType, attributes(enumset))]
450 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
451     let input: DeriveInput = parse_macro_input!(input);
452     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
453         Ok(attrs) => attrs,
454         Err(e) => return e.write_errors().into(),
455     };
456     match derive_enum_set_type_impl(input, attrs) {
457         Ok(v) => v,
458         Err(e) => e.to_compile_error().into(),
459     }
460 }