]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
Add UI tests.
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit="256"]
2 #![cfg_attr(feature = "nightly", feature(proc_macro_diagnostic))]
3
4 // TODO: Read #[repr(...)] attributes.
5
6 extern crate proc_macro;
7
8 use darling::*;
9 use proc_macro::TokenStream;
10 use proc_macro2::{TokenStream as SynTokenStream, Literal};
11 use syn::{*, Result, Error};
12 use syn::export::Span;
13 use syn::spanned::Spanned;
14 use quote::*;
15
16 fn error<T>(span: Span, message: &str) -> Result<T> {
17     Err(Error::new(span, message))
18 }
19
20 fn enum_set_type_impl(
21     name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
22     enum_repr: Ident,
23 ) -> SynTokenStream {
24     let is_uninhabited = variants.is_empty();
25     let is_zst = variants.len() == 1;
26
27     let typed_enumset = quote!(::enumset::EnumSet<#name>);
28     let core = quote!(::enumset::internal::core_export);
29     #[cfg(feature = "serde")]
30     let serde = quote!(::enumset::internal::serde);
31
32     // proc_macro2 does not support creating u128 literals.
33     let all_variants = Literal::u128_unsuffixed(all_variants);
34
35     let ops = if attrs.no_ops {
36         quote! {}
37     } else {
38         quote! {
39             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
40                 type Output = #typed_enumset;
41                 fn sub(self, other: O) -> Self::Output {
42                     ::enumset::EnumSet::only(self) - other.into()
43                 }
44             }
45             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
46                 type Output = #typed_enumset;
47                 fn bitand(self, other: O) -> Self::Output {
48                     ::enumset::EnumSet::only(self) & other.into()
49                 }
50             }
51             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
52                 type Output = #typed_enumset;
53                 fn bitor(self, other: O) -> Self::Output {
54                     ::enumset::EnumSet::only(self) | other.into()
55                 }
56             }
57             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
58                 type Output = #typed_enumset;
59                 fn bitxor(self, other: O) -> Self::Output {
60                     ::enumset::EnumSet::only(self) ^ other.into()
61                 }
62             }
63             impl #core::ops::Not for #name {
64                 type Output = #typed_enumset;
65                 fn not(self) -> Self::Output {
66                     !::enumset::EnumSet::only(self)
67                 }
68             }
69             impl #core::cmp::PartialEq<#typed_enumset> for #name {
70                 fn eq(&self, other: &#typed_enumset) -> bool {
71                     ::enumset::EnumSet::only(*self) == *other
72                 }
73             }
74         }
75     };
76
77     #[cfg(feature = "serde")]
78     let serde_ops = if attrs.serialize_as_list {
79         let expecting_str = format!("a list of {}", name);
80         quote! {
81             fn serialize<S: #serde::Serializer>(
82                 set: ::enumset::EnumSet<#name>, ser: S,
83             ) -> #core::result::Result<S::Ok, S::Error> {
84                 use #serde::ser::SerializeSeq;
85                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
86                 for bit in set {
87                     seq.serialize_element(&bit)?;
88                 }
89                 seq.end()
90             }
91             fn deserialize<'de, D: #serde::Deserializer<'de>>(
92                 de: D,
93             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
94                 struct Visitor;
95                 impl <'de> #serde::de::Visitor<'de> for Visitor {
96                     type Value = ::enumset::EnumSet<#name>;
97                     fn expecting(
98                         &self, formatter: &mut #core::fmt::Formatter,
99                     ) -> #core::fmt::Result {
100                         write!(formatter, #expecting_str)
101                     }
102                     fn visit_seq<A>(
103                         mut self, mut seq: A,
104                     ) -> #core::result::Result<Self::Value, A::Error> where
105                         A: #serde::de::SeqAccess<'de>
106                     {
107                         let mut accum = ::enumset::EnumSet::<#name>::new();
108                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
109                             accum |= val;
110                         }
111                         #core::prelude::v1::Ok(accum)
112                     }
113                 }
114                 de.deserialize_seq(Visitor)
115             }
116         }
117     } else {
118         let serialize_repr = attrs.serialize_repr.as_ref()
119             .map(|x| Ident::new(&x, Span::call_site()))
120             .unwrap_or(repr.clone());
121         let check_unknown = if attrs.serialize_deny_unknown {
122             quote! {
123                 if value & !#all_variants != 0 {
124                     use #serde::de::Error;
125                     return #core::prelude::v1::Err(
126                         D::Error::custom("enumset contains unknown bits")
127                     )
128                 }
129             }
130         } else {
131             quote! { }
132         };
133         quote! {
134             fn serialize<S: #serde::Serializer>(
135                 set: ::enumset::EnumSet<#name>, ser: S,
136             ) -> #core::result::Result<S::Ok, S::Error> {
137                 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
138             }
139             fn deserialize<'de, D: #serde::Deserializer<'de>>(
140                 de: D,
141             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
142                 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
143                 #check_unknown
144                 #core::prelude::v1::Ok(::enumset::EnumSet {
145                     __enumset_underlying: (value & #all_variants) as #repr,
146                 })
147             }
148         }
149     };
150
151     #[cfg(not(feature = "serde"))]
152     let serde_ops = quote! { };
153
154     let into_impl = if is_uninhabited {
155         quote! {
156             fn enum_into_u32(self) -> u32 {
157                 panic!(concat!(stringify!(#name), " is uninhabited."))
158             }
159             unsafe fn enum_from_u32(val: u32) -> Self {
160                 panic!(concat!(stringify!(#name), " is uninhabited."))
161             }
162         }
163     } else if is_zst {
164         let variant = &variants[0];
165         quote! {
166             fn enum_into_u32(self) -> u32 {
167                 self as u32
168             }
169             unsafe fn enum_from_u32(val: u32) -> Self {
170                 #name::#variant
171             }
172         }
173     } else {
174         quote! {
175             fn enum_into_u32(self) -> u32 {
176                 self as u32
177             }
178             unsafe fn enum_from_u32(val: u32) -> Self {
179                 #core::mem::transmute(val as #enum_repr)
180             }
181         }
182     };
183
184     let eq_impl = if is_uninhabited {
185         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
186     } else {
187         quote!((*self as u32) == (*other as u32))
188     };
189
190     quote! {
191         unsafe impl ::enumset::internal::EnumSetTypePrivate for #name {
192             type Repr = #repr;
193             const ALL_BITS: Self::Repr = #all_variants;
194             #into_impl
195             #serde_ops
196         }
197
198         unsafe impl ::enumset::EnumSetType for #name { }
199
200         impl #core::cmp::PartialEq for #name {
201             fn eq(&self, other: &Self) -> bool {
202                 #eq_impl
203             }
204         }
205         impl #core::cmp::Eq for #name { }
206         impl #core::clone::Clone for #name {
207             fn clone(&self) -> Self {
208                 *self
209             }
210         }
211         impl #core::marker::Copy for #name { }
212
213         #ops
214     }
215 }
216
217 #[derive(FromDeriveInput, Default)]
218 #[darling(attributes(enumset), default)]
219 struct EnumsetAttrs {
220     no_ops: bool,
221     serialize_as_list: bool,
222     serialize_deny_unknown: bool,
223     #[darling(default)]
224     serialize_repr: Option<String>,
225 }
226
227 fn derive_enum_set_type_impl(input: DeriveInput) -> Result<TokenStream> {
228     if let Data::Enum(data) = &input.data {
229         if !input.generics.params.is_empty() {
230             error(
231                 input.generics.span(),
232                 "`#[derive(EnumSetType)]` cannot be used on enums with type parameters."
233             )
234         } else {
235             let mut all_variants = 0u128;
236             let mut max_variant = 0u32;
237             let mut current_variant = 0u32;
238             let mut has_manual_discriminant = false;
239             let mut variants = Vec::new();
240
241             for variant in &data.variants {
242                 if let Fields::Unit = variant.fields {
243                     if let Some((_, expr)) = &variant.discriminant {
244                         if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
245                             current_variant = match i.base10_parse() {
246                                 Ok(val) => val,
247                                 Err(_) => error(
248                                     expr.span(), "Enum set discriminants must be `u32`s.",
249                                 )?,
250                             };
251                             has_manual_discriminant = true;
252                         } else {
253                             error(
254                                 variant.span(), "Enum set discriminants must be `u32`s."
255                             )?;
256                         }
257                     }
258
259                     if current_variant >= 128 {
260                         let message = if has_manual_discriminant {
261                             "`#[derive(EnumSetType)]` currently only supports \
262                              enum discriminants up to 127."
263                         } else {
264                             "`#[derive(EnumSetType)]` currently only supports \
265                              enums up to 128 variants."
266                         };
267                         error(variant.span(), message)?;
268                     }
269
270                     if all_variants & (1 << current_variant as u128) != 0 {
271                         error(
272                             variant.span(),
273                             &format!("Duplicate enum discriminant: {}", current_variant)
274                         )?;
275                     }
276                     all_variants |= 1 << current_variant as u128;
277                     if current_variant > max_variant {
278                         max_variant = current_variant
279                     }
280
281                     variants.push(variant.ident.clone());
282                     current_variant += 1;
283                 } else {
284                     error(
285                         variant.span(),
286                         "`#[derive(EnumSetType)]` can only be used on fieldless enums."
287                     )?;
288                 }
289             }
290
291             let repr = Ident::new(if max_variant <= 7 {
292                 "u8"
293             } else if max_variant <= 15 {
294                 "u16"
295             } else if max_variant <= 31 {
296                 "u32"
297             } else if max_variant <= 63 {
298                 "u64"
299             } else if max_variant <= 127 {
300                 "u128"
301             } else {
302                 panic!("max_variant > 127?")
303             }, Span::call_site());
304
305             let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
306                 Ok(attrs) => attrs,
307                 Err(e) => return Ok(e.write_errors().into()),
308             };
309
310             let mut enum_repr = None;
311             for attr in &input.attrs {
312                 if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
313                     let meta: Ident = attr.parse_args()?;
314                     if enum_repr.is_some() {
315                         error(attr.span(), "Cannot duplicate #[repr(...)] annotations.")?;
316                     }
317                     let repr_max_variant = match meta.to_string().as_str() {
318                         "u8" => 0xFF,
319                         "u16" => 0xFFFF,
320                         "u32" => 0xFFFFFFFF,
321                         _ => error(attr.span(), "Only `u8`, `u16` and `u32` reprs are supported.")?,
322                     };
323                     if max_variant > repr_max_variant {
324                         error(attr.span(), "A variant of this enum overflows its repr.")?;
325                     }
326                     enum_repr = Some(meta);
327                 }
328             }
329             let enum_repr = enum_repr.unwrap_or_else(|| if max_variant < 0x100 {
330                 Ident::new("u8", Span::call_site())
331             } else if max_variant < 0x10000 {
332                 Ident::new("u16", Span::call_site())
333             } else {
334                 Ident::new("u32", Span::call_site())
335             });
336
337             match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
338                 Some("u8") => if max_variant > 7 {
339                     error(input.span(), "Too many variants for u8 serialization repr.")?;
340                 }
341                 Some("u16") => if max_variant > 15 {
342                     error(input.span(), "Too many variants for u16 serialization repr.")?;
343                 }
344                 Some("u32") => if max_variant > 31 {
345                     error(input.span(), "Too many variants for u32 serialization repr.")?;
346                 }
347                 Some("u64") => if max_variant > 63 {
348                     error(input.span(), "Too many variants for u64 serialization repr.")?;
349                 }
350                 Some("u128") => if max_variant > 127 {
351                     error(input.span(), "Too many variants for u128 serialization repr.")?;
352                 }
353                 None => { }
354                 Some(x) => error(
355                     input.span(), &format!("{} is not a valid serialization repr.", x)
356                 )?,
357             };
358
359             Ok(enum_set_type_impl(
360                 &input.ident, all_variants, repr, attrs, variants, enum_repr,
361             ).into())
362         }
363     } else {
364         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
365     }
366 }
367
368 #[proc_macro_derive(EnumSetType, attributes(enumset))]
369 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
370     let input: DeriveInput = parse_macro_input!(input);
371     match derive_enum_set_type_impl(input) {
372         Ok(v) => v,
373         Err(e) => e.to_compile_error().into(),
374     }
375 }