]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
Make enumset_derive more resiliant against potential future changes to the representa...
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit="256"]
2
3 extern crate proc_macro;
4
5 use darling::*;
6 use proc_macro::TokenStream;
7 use proc_macro2::{TokenStream as SynTokenStream, Literal};
8 use std::collections::HashSet;
9 use syn::{*, Result, Error};
10 use syn::export::Span;
11 use syn::spanned::Spanned;
12 use quote::*;
13
14 /// Helper function for emitting compile errors.
15 fn error<T>(span: Span, message: &str) -> Result<T> {
16     Err(Error::new(span, message))
17 }
18
19 /// Decodes the custom attributes for our custom derive.
20 #[derive(FromDeriveInput, Default)]
21 #[darling(attributes(enumset), default)]
22 struct EnumsetAttrs {
23     no_ops: bool,
24     serialize_as_list: bool,
25     serialize_deny_unknown: bool,
26     #[darling(default)]
27     serialize_repr: Option<String>,
28 }
29
30 /// An variant in the enum set type.
31 struct EnumSetValue {
32     name: Ident,
33     variant_repr: u32,
34 }
35
36 /// Stores information about the enum set type.
37 #[allow(dead_code)]
38 struct EnumSetInfo {
39     name: Ident,
40     explicit_serde_repr: Option<Ident>,
41     has_signed_repr: bool,
42     has_large_repr: bool,
43     variants: Vec<EnumSetValue>,
44
45     max_discrim: u32,
46     cur_discrim: u32,
47     used_variant_names: HashSet<String>,
48     used_discriminators: HashSet<u32>,
49
50     no_ops: bool,
51     serialize_as_list: bool,
52     serialize_deny_unknown: bool,
53 }
54 impl EnumSetInfo {
55     fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
56         EnumSetInfo {
57             name: input.ident.clone(),
58             explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
59             has_signed_repr: false,
60             has_large_repr: false,
61             variants: Vec::new(),
62             max_discrim: 0,
63             cur_discrim: 0,
64             used_variant_names: HashSet::new(),
65             used_discriminators: HashSet::new(),
66             no_ops: attrs.no_ops,
67             serialize_as_list: attrs.serialize_as_list,
68             serialize_deny_unknown: attrs.serialize_deny_unknown
69         }
70     }
71
72     fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
73         match repr {
74             "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
75             "usize" | "u64" | "u128" => {
76                 self.has_large_repr = true;
77                 Ok(())
78             }
79             "i8" | "i16" | "i32" => {
80                 self.has_signed_repr = true;
81                 Ok(())
82             }
83             "isize" | "i64" | "i128" => {
84                 self.has_signed_repr = true;
85                 self.has_large_repr = true;
86                 Ok(())
87             }
88             _ => error(attr_span, "Unsupported repr.")
89         }
90     }
91     fn push_variant(&mut self, variant: &Variant) -> Result<()> {
92         if self.used_variant_names.contains(&variant.ident.to_string()) {
93             error(variant.span(), "Duplicated variant name.")
94         } else if let Fields::Unit = variant.fields {
95             if let Some((_, expr)) = &variant.discriminant {
96                 let discriminant_fail_message = format!(
97                     "Enum set discriminants must be `u32`s.{}",
98                     if self.has_signed_repr || self.has_large_repr {
99                         format!(
100                             " ({} discrimiants are still unsupported with reprs that allow them.)",
101                             if self.has_large_repr {
102                                 "larger"
103                             } else if self.has_signed_repr {
104                                 "negative"
105                             } else {
106                                 "larger or negative"
107                             }
108                         )
109                     } else {
110                         String::new()
111                     },
112                 );
113                 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
114                     match i.base10_parse() {
115                         Ok(val) => self.cur_discrim = val,
116                         Err(_) => error(expr.span(), &discriminant_fail_message)?,
117                     }
118                 } else {
119                     error(variant.span(), &discriminant_fail_message)?;
120                 }
121             }
122
123             let discriminant = self.cur_discrim;
124             if discriminant >= 128 {
125                 let message = if self.variants.len() <= 127 {
126                     "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
127                 } else {
128                     "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
129                 };
130                 error(variant.span(), message)?;
131             }
132
133             if self.used_discriminators.contains(&discriminant) {
134                 error(variant.span(), "Duplicated enum discriminant.")?;
135             }
136
137             self.cur_discrim += 1;
138             if discriminant > self.max_discrim {
139                 self.max_discrim = discriminant;
140             }
141             self.variants.push(EnumSetValue {
142                 name: variant.ident.clone(),
143                 variant_repr: discriminant,
144             });
145             self.used_variant_names.insert(variant.ident.to_string());
146             self.used_discriminators.insert(discriminant);
147
148             Ok(())
149         } else {
150             error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
151         }
152     }
153     fn validate(&self) -> Result<()> {
154         if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
155             let is_overflowed = match explicit_serde_repr.to_string().as_str() {
156                 "u8" => self.max_discrim >= 8,
157                 "u16" => self.max_discrim >= 16,
158                 "u32" => self.max_discrim >= 32,
159                 "u64" => self.max_discrim >= 64,
160                 "u128" => self.max_discrim >= 128,
161                 _ => error(
162                     Span::call_site(),
163                     "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
164                 )?,
165             };
166             if is_overflowed {
167                 error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
168             }
169         }
170         Ok(())
171     }
172
173     fn enumset_repr(&self) -> SynTokenStream {
174         if self.max_discrim <= 7 {
175             quote! { u8 }
176         } else if self.max_discrim <= 15 {
177             quote! { u16 }
178         } else if self.max_discrim <= 31 {
179             quote! { u32 }
180         } else if self.max_discrim <= 63 {
181             quote! { u64 }
182         } else if self.max_discrim <= 127 {
183             quote! { u128 }
184         } else {
185             panic!("max_variant > 127?")
186         }
187     }
188     #[cfg(feature = "serde")]
189     fn serde_repr(&self) -> SynTokenStream {
190         if let Some(serde_repr) = &self.explicit_serde_repr {
191             quote! { #serde_repr }
192         } else {
193             self.enumset_repr()
194         }
195     }
196
197     fn all_variants(&self) -> u128 {
198         let mut accum = 0u128;
199         for variant in &self.variants {
200             assert!(variant.variant_repr <= 127);
201             accum |= 1u128 << variant.variant_repr as u128;
202         }
203         accum
204     }
205 }
206
207 fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
208     let name = &info.name;
209     let typed_enumset = quote!(::enumset::EnumSet<#name>);
210     let core = quote!(::enumset::internal::core_export);
211
212     let repr = info.enumset_repr();
213     let all_variants = Literal::u128_unsuffixed(info.all_variants());
214
215     let ops = if info.no_ops {
216         quote! {}
217     } else {
218         quote! {
219             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
220                 type Output = #typed_enumset;
221                 fn sub(self, other: O) -> Self::Output {
222                     ::enumset::EnumSet::only(self) - other.into()
223                 }
224             }
225             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
226                 type Output = #typed_enumset;
227                 fn bitand(self, other: O) -> Self::Output {
228                     ::enumset::EnumSet::only(self) & other.into()
229                 }
230             }
231             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
232                 type Output = #typed_enumset;
233                 fn bitor(self, other: O) -> Self::Output {
234                     ::enumset::EnumSet::only(self) | other.into()
235                 }
236             }
237             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
238                 type Output = #typed_enumset;
239                 fn bitxor(self, other: O) -> Self::Output {
240                     ::enumset::EnumSet::only(self) ^ other.into()
241                 }
242             }
243             impl #core::ops::Not for #name {
244                 type Output = #typed_enumset;
245                 fn not(self) -> Self::Output {
246                     !::enumset::EnumSet::only(self)
247                 }
248             }
249             impl #core::cmp::PartialEq<#typed_enumset> for #name {
250                 fn eq(&self, other: &#typed_enumset) -> bool {
251                     ::enumset::EnumSet::only(*self) == *other
252                 }
253             }
254         }
255     };
256
257
258     #[cfg(feature = "serde")]
259     let serde = quote!(::enumset::internal::serde);
260
261     #[cfg(feature = "serde")]
262     let serde_ops = if info.serialize_as_list {
263         let expecting_str = format!("a list of {}", name);
264         quote! {
265             fn serialize<S: #serde::Serializer>(
266                 set: ::enumset::EnumSet<#name>, ser: S,
267             ) -> #core::result::Result<S::Ok, S::Error> {
268                 use #serde::ser::SerializeSeq;
269                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
270                 for bit in set {
271                     seq.serialize_element(&bit)?;
272                 }
273                 seq.end()
274             }
275             fn deserialize<'de, D: #serde::Deserializer<'de>>(
276                 de: D,
277             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
278                 struct Visitor;
279                 impl <'de> #serde::de::Visitor<'de> for Visitor {
280                     type Value = ::enumset::EnumSet<#name>;
281                     fn expecting(
282                         &self, formatter: &mut #core::fmt::Formatter,
283                     ) -> #core::fmt::Result {
284                         write!(formatter, #expecting_str)
285                     }
286                     fn visit_seq<A>(
287                         mut self, mut seq: A,
288                     ) -> #core::result::Result<Self::Value, A::Error> where
289                         A: #serde::de::SeqAccess<'de>
290                     {
291                         let mut accum = ::enumset::EnumSet::<#name>::new();
292                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
293                             accum |= val;
294                         }
295                         #core::prelude::v1::Ok(accum)
296                     }
297                 }
298                 de.deserialize_seq(Visitor)
299             }
300         }
301     } else {
302         let serialize_repr = info.serde_repr();
303         let check_unknown = if info.serialize_deny_unknown {
304             quote! {
305                 if value & !#all_variants != 0 {
306                     use #serde::de::Error;
307                     return #core::prelude::v1::Err(
308                         D::Error::custom("enumset contains unknown bits")
309                     )
310                 }
311             }
312         } else {
313             quote! { }
314         };
315         quote! {
316             fn serialize<S: #serde::Serializer>(
317                 set: ::enumset::EnumSet<#name>, ser: S,
318             ) -> #core::result::Result<S::Ok, S::Error> {
319                 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
320             }
321             fn deserialize<'de, D: #serde::Deserializer<'de>>(
322                 de: D,
323             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
324                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
325                 #check_unknown
326                 #core::prelude::v1::Ok(::enumset::EnumSet {
327                     __enumset_underlying: (value & #all_variants) as #repr,
328                 })
329             }
330         }
331     };
332
333     #[cfg(not(feature = "serde"))]
334     let serde_ops = quote! { };
335
336     let is_uninhabited = info.variants.is_empty();
337     let is_zst = info.variants.len() == 1;
338     let into_impl = if is_uninhabited {
339         quote! {
340             fn enum_into_u32(self) -> u32 {
341                 panic!(concat!(stringify!(#name), " is uninhabited."))
342             }
343             unsafe fn enum_from_u32(val: u32) -> Self {
344                 panic!(concat!(stringify!(#name), " is uninhabited."))
345             }
346         }
347     } else if is_zst {
348         let variant = &info.variants[0].name;
349         quote! {
350             fn enum_into_u32(self) -> u32 {
351                 self as u32
352             }
353             unsafe fn enum_from_u32(val: u32) -> Self {
354                 #name::#variant
355             }
356         }
357     } else {
358         let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
359         let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
360
361         let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
362             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
363         let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
364             .iter().map(|x| Ident::new(x, Span::call_site())).collect();
365
366         quote! {
367             fn enum_into_u32(self) -> u32 {
368                 self as u32
369             }
370             unsafe fn enum_from_u32(val: u32) -> Self {
371                 // We put these in const fields so they aren't generated even on -O0
372                 #(const #const_field: bool =
373                     #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
374                 match val {
375                     // Every valid variant value has an explicit branch. If they get optimized out,
376                     // great. If the representation has changed somehow, and they don't, oh well,
377                     // there's still no UB.
378                     #(#variant_value => #name::#variant_name,)*
379                     // Helps hint to the LLVM that this is a transmute. Note that this branch is
380                     // still unreachable.
381                     #(x if #const_field => {
382                         let x = x as #int_type;
383                         *(&x as *const _ as *const #name)
384                     })*
385                     // Default case. Sometimes causes LLVM to generate a table instead of a simple
386                     // transmute, but, oh well.
387                     _ => #core::hint::unreachable_unchecked(),
388                 }
389             }
390         }
391     };
392
393     let eq_impl = if is_uninhabited {
394         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
395     } else {
396         quote!((*self as u32) == (*other as u32))
397     };
398
399     quote! {
400         unsafe impl ::enumset::internal::EnumSetTypePrivate for #name {
401             type Repr = #repr;
402             const ALL_BITS: Self::Repr = #all_variants;
403             #into_impl
404             #serde_ops
405         }
406
407         unsafe impl ::enumset::EnumSetType for #name { }
408
409         impl #core::cmp::PartialEq for #name {
410             fn eq(&self, other: &Self) -> bool {
411                 #eq_impl
412             }
413         }
414         impl #core::cmp::Eq for #name { }
415         impl #core::clone::Clone for #name {
416             fn clone(&self) -> Self {
417                 *self
418             }
419         }
420         impl #core::marker::Copy for #name { }
421
422         #ops
423     }
424 }
425
426 fn derive_enum_set_type_impl(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
427     if !input.generics.params.is_empty() {
428         error(
429             input.generics.span(),
430             "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
431         )
432     } else if let Data::Enum(data) = &input.data {
433         let mut info = EnumSetInfo::new(&input, attrs);
434         for attr in &input.attrs {
435             if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
436                 let meta: Ident = attr.parse_args()?;
437                 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
438             }
439         }
440         for variant in &data.variants {
441             info.push_variant(variant)?;
442         }
443         info.validate()?;
444         Ok(enum_set_type_impl(info).into())
445     } else {
446         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
447     }
448 }
449
450 #[proc_macro_derive(EnumSetType, attributes(enumset))]
451 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
452     let input: DeriveInput = parse_macro_input!(input);
453     let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
454         Ok(attrs) => attrs,
455         Err(e) => return e.write_errors().into(),
456     };
457     match derive_enum_set_type_impl(input, attrs) {
458         Ok(v) => v,
459         Err(e) => e.to_compile_error().into(),
460     }
461 }