]> git.lizzy.rs Git - enumset.git/blob - enumset_derive/src/lib.rs
Implement serialize_deny_unknown, improve serialization tests.
[enumset.git] / enumset_derive / src / lib.rs
1 #![recursion_limit="256"]
2 #![cfg_attr(feature = "nightly", feature(proc_macro_diagnostic))]
3
4 extern crate darling;
5 extern crate syn;
6 extern crate proc_macro;
7 extern crate proc_macro2;
8 extern crate quote;
9
10 use darling::*;
11 use proc_macro::TokenStream;
12 use proc_macro2::{TokenStream as SynTokenStream, Literal};
13 use syn::*;
14 use syn::export::Span;
15 use syn::spanned::Spanned;
16 use quote::*;
17
18 #[cfg(feature = "nightly")]
19 fn error(span: Span, data: &str) -> TokenStream {
20     span.unstable().error(data).emit();
21     TokenStream::new()
22 }
23
24 #[cfg(not(feature = "nightly"))]
25 fn error(_: Span, data: &str) -> TokenStream {
26     panic!("{}", data)
27 }
28
29 fn enum_set_type_impl(
30     name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs,
31 ) -> SynTokenStream {
32     let typed_enumset = quote!(::enumset::EnumSet<#name>);
33     let core = quote!(::enumset::internal::core);
34     #[cfg(feature = "serde")]
35     let serde = quote!(::enumset::internal::serde);
36
37     // proc_macro2 does not support creating u128 literals.
38     let all_variants = Literal::u128_unsuffixed(all_variants);
39
40     let ops = if attrs.no_ops {
41         quote! {}
42     } else {
43         quote! {
44             impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
45                 type Output = #typed_enumset;
46                 fn sub(self, other: O) -> Self::Output {
47                     ::enumset::EnumSet::only(self) - other.into()
48                 }
49             }
50             impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
51                 type Output = #typed_enumset;
52                 fn bitand(self, other: O) -> Self::Output {
53                     ::enumset::EnumSet::only(self) & other.into()
54                 }
55             }
56             impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
57                 type Output = #typed_enumset;
58                 fn bitor(self, other: O) -> Self::Output {
59                     ::enumset::EnumSet::only(self) | other.into()
60                 }
61             }
62             impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
63                 type Output = #typed_enumset;
64                 fn bitxor(self, other: O) -> Self::Output {
65                     ::enumset::EnumSet::only(self) ^ other.into()
66                 }
67             }
68             impl #core::ops::Not for #name {
69                 type Output = #typed_enumset;
70                 fn not(self) -> Self::Output {
71                     !::enumset::EnumSet::only(self)
72                 }
73             }
74             impl #core::cmp::PartialEq<#typed_enumset> for #name {
75                 fn eq(&self, other: &#typed_enumset) -> bool {
76                     ::enumset::EnumSet::only(*self) == *other
77                 }
78             }
79         }
80     };
81
82     #[cfg(feature = "serde")]
83     let serde_ops = if attrs.serialize_as_list {
84         let expecting_str = format!("a list of {}", name);
85         quote! {
86             fn serialize<S: #serde::Serializer>(
87                 set: ::enumset::EnumSet<#name>, ser: S,
88             ) -> #core::result::Result<S::Ok, S::Error> {
89                 use #serde::ser::SerializeSeq;
90                 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
91                 for bit in set {
92                     seq.serialize_element(&bit)?;
93                 }
94                 seq.end()
95             }
96             fn deserialize<'de, D: #serde::Deserializer<'de>>(
97                 de: D,
98             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
99                 struct Visitor;
100                 impl <'de> #serde::de::Visitor<'de> for Visitor {
101                     type Value = ::enumset::EnumSet<#name>;
102                     fn expecting(
103                         &self, formatter: &mut #core::fmt::Formatter,
104                     ) -> #core::fmt::Result {
105                         write!(formatter, #expecting_str)
106                     }
107                     fn visit_seq<A>(
108                         mut self, mut seq: A,
109                     ) -> Result<Self::Value, A::Error> where A: #serde::de::SeqAccess<'de> {
110                         let mut accum = ::enumset::EnumSet::<#name>::new();
111                         while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
112                             accum |= val;
113                         }
114                         #core::prelude::v1::Ok(accum)
115                     }
116                 }
117                 de.deserialize_seq(Visitor)
118             }
119         }
120     } else {
121         let serialize_repr = attrs.serialize_repr.as_ref()
122             .map(|x| Ident::new(&x, Span::call_site()))
123             .unwrap_or(repr.clone());
124         let check_unknown = if attrs.serialize_deny_unknown {
125             quote! {
126                 if value & !#all_variants != 0 {
127                     use #serde::de::Error;
128                     let unexpected = serde::de::Unexpected::Unsigned(value as u64);
129                     return #core::prelude::v1::Err(
130                         D::Error::custom("enumset contains unknown bits")
131                     )
132                 }
133             }
134         } else {
135             quote! { }
136         };
137         quote! {
138             fn serialize<S: #serde::Serializer>(
139                 set: ::enumset::EnumSet<#name>, ser: S,
140             ) -> #core::result::Result<S::Ok, S::Error> {
141                 use #serde::Serialize;
142                 #serialize_repr::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
143             }
144             fn deserialize<'de, D: #serde::Deserializer<'de>>(
145                 de: D,
146             ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
147                 use #serde::Deserialize;
148                 let value = #serialize_repr::deserialize(de)?;
149                 #check_unknown
150                 #core::prelude::v1::Ok(::enumset::EnumSet {
151                     __enumset_underlying: (value & #all_variants) as #repr,
152                 })
153             }
154         }
155     };
156
157     #[cfg(not(feature = "serde"))]
158     let serde_ops = quote! { };
159
160     quote! {
161         unsafe impl ::enumset::EnumSetType for #name {
162             type Repr = #repr;
163             const ALL_BITS: Self::Repr = #all_variants;
164
165             fn enum_into_u8(self) -> u8 {
166                 self as u8
167             }
168             unsafe fn enum_from_u8(val: u8) -> Self {
169                 #core::mem::transmute(val)
170             }
171
172             #serde_ops
173         }
174
175         impl #core::cmp::PartialEq for #name {
176             fn eq(&self, other: &Self) -> bool {
177                 (*self as u8) == (*other as u8)
178             }
179         }
180         impl #core::cmp::Eq for #name { }
181         impl #core::clone::Clone for #name {
182             fn clone(&self) -> Self {
183                 *self
184             }
185         }
186         impl #core::marker::Copy for #name { }
187
188         #ops
189     }
190 }
191
192 #[derive(FromDeriveInput, Default)]
193 #[darling(attributes(enumset), default)]
194 struct EnumsetAttrs {
195     no_ops: bool,
196     serialize_as_list: bool,
197     serialize_deny_unknown: bool,
198     #[darling(default)]
199     serialize_repr: Option<String>,
200 }
201
202 #[proc_macro_derive(EnumSetType, attributes(enumset))]
203 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
204     let input: DeriveInput = parse_macro_input!(input);
205     if let Data::Enum(data) = &input.data {
206         if !input.generics.params.is_empty() {
207             error(input.generics.span(),
208                   "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.")
209         } else {
210             let mut all_variants = 0u128;
211             let mut max_variant = 0;
212             let mut current_variant = 0;
213             let mut has_manual_discriminant = false;
214
215             for variant in &data.variants {
216                 if let Fields::Unit = variant.fields {
217                     if let Some((_, expr)) = &variant.discriminant {
218                         if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
219                             current_variant = i.value();
220                             has_manual_discriminant = true;
221                         } else {
222                             return error(variant.span(), "Unrecognized discriminant for variant.")
223                         }
224                     }
225
226                     if current_variant >= 128 {
227                         let message = if has_manual_discriminant {
228                             "`#[derive(EnumSetType)]` only supports enum discriminants up to 127."
229                         } else {
230                             "`#[derive(EnumSetType)]` only supports enums up to 128 variants."
231                         };
232                         return error(variant.span(), message)
233                     }
234
235                     if all_variants & (1 << current_variant) != 0 {
236                         return error(variant.span(),
237                                      &format!("Duplicate enum discriminant: {}", current_variant))
238                     }
239                     all_variants |= 1 << current_variant;
240                     if current_variant > max_variant {
241                         max_variant = current_variant
242                     }
243
244                     current_variant += 1;
245                 } else {
246                     return error(variant.span(),
247                                  "`#[derive(EnumSetType)]` can only be used on C-like enums.")
248                 }
249             }
250
251             let repr = Ident::new(if max_variant <= 7 {
252                 "u8"
253             } else if max_variant <= 15 {
254                 "u16"
255             } else if max_variant <= 31 {
256                 "u32"
257             } else if max_variant <= 63 {
258                 "u64"
259             } else if max_variant <= 127 {
260                 "u128"
261             } else {
262                 panic!("max_variant > 127?")
263             }, Span::call_site());
264
265             let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
266                 Ok(attrs) => attrs,
267                 Err(e) => return e.write_errors().into(),
268             };
269
270             match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
271                 Some("u8") => if max_variant > 7 {
272                     return error(input.span(), "Too many variants for u8 serialization repr.")
273                 }
274                 Some("u16") => if max_variant > 15 {
275                     return error(input.span(), "Too many variants for u16 serialization repr.")
276                 }
277                 Some("u32") => if max_variant > 31 {
278                     return error(input.span(), "Too many variants for u32 serialization repr.")
279                 }
280                 Some("u64") => if max_variant > 63 {
281                     return error(input.span(), "Too many variants for u64 serialization repr.")
282                 }
283                 Some("u128") => if max_variant > 127 {
284                     return error(input.span(), "Too many variants for u128 serialization repr.")
285                 }
286                 None => { }
287                 Some(x) => return error(input.span(),
288                                         &format!("{} is not a valid serialization repr.", x)),
289             };
290
291             enum_set_type_impl(&input.ident, all_variants, repr, attrs).into()
292         }
293     } else {
294         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
295     }
296 }