1 #![recursion_limit="256"]
2 #![cfg_attr(feature = "nightly", feature(proc_macro_diagnostic))]
6 extern crate proc_macro;
7 extern crate proc_macro2;
11 use proc_macro::TokenStream;
12 use proc_macro2::{TokenStream as SynTokenStream, Literal};
14 use syn::export::Span;
15 use syn::spanned::Spanned;
18 #[cfg(feature = "nightly")]
19 fn error(span: Span, data: &str) -> TokenStream {
20 span.unstable().error(data).emit();
24 #[cfg(not(feature = "nightly"))]
25 fn error(_: Span, data: &str) -> TokenStream {
29 fn enum_set_type_impl(
30 name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs,
32 let typed_enumset = quote!(::enumset::EnumSet<#name>);
33 let core = quote!(::enumset::internal::core);
34 #[cfg(feature = "serde")]
35 let serde = quote!(::enumset::internal::serde);
37 // proc_macro2 does not support creating u128 literals.
38 let all_variants = Literal::u128_unsuffixed(all_variants);
40 let ops = if attrs.no_ops {
44 impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
45 type Output = #typed_enumset;
46 fn sub(self, other: O) -> Self::Output {
47 ::enumset::EnumSet::only(self) - other.into()
50 impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
51 type Output = #typed_enumset;
52 fn bitand(self, other: O) -> Self::Output {
53 ::enumset::EnumSet::only(self) & other.into()
56 impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
57 type Output = #typed_enumset;
58 fn bitor(self, other: O) -> Self::Output {
59 ::enumset::EnumSet::only(self) | other.into()
62 impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
63 type Output = #typed_enumset;
64 fn bitxor(self, other: O) -> Self::Output {
65 ::enumset::EnumSet::only(self) ^ other.into()
68 impl #core::ops::Not for #name {
69 type Output = #typed_enumset;
70 fn not(self) -> Self::Output {
71 !::enumset::EnumSet::only(self)
74 impl #core::cmp::PartialEq<#typed_enumset> for #name {
75 fn eq(&self, other: &#typed_enumset) -> bool {
76 ::enumset::EnumSet::only(*self) == *other
82 #[cfg(feature = "serde")]
83 let serde_ops = if attrs.serialize_as_list {
84 let expecting_str = format!("a list of {}", name);
86 fn serialize<S: #serde::Serializer>(
87 set: ::enumset::EnumSet<#name>, ser: S,
88 ) -> #core::result::Result<S::Ok, S::Error> {
89 use #serde::ser::SerializeSeq;
90 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
92 seq.serialize_element(&bit)?;
96 fn deserialize<'de, D: #serde::Deserializer<'de>>(
98 ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
100 impl <'de> #serde::de::Visitor<'de> for Visitor {
101 type Value = ::enumset::EnumSet<#name>;
103 &self, formatter: &mut #core::fmt::Formatter,
104 ) -> #core::fmt::Result {
105 write!(formatter, #expecting_str)
108 mut self, mut seq: A,
109 ) -> Result<Self::Value, A::Error> where A: #serde::de::SeqAccess<'de> {
110 let mut accum = ::enumset::EnumSet::<#name>::new();
111 while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
114 #core::prelude::v1::Ok(accum)
117 de.deserialize_seq(Visitor)
121 let serialize_repr = attrs.serialize_repr.as_ref()
122 .map(|x| Ident::new(&x, Span::call_site()))
123 .unwrap_or(repr.clone());
124 let check_unknown = if attrs.serialize_deny_unknown {
126 if value & !#all_variants != 0 {
127 use #serde::de::Error;
128 let unexpected = serde::de::Unexpected::Unsigned(value as u64);
129 return #core::prelude::v1::Err(
130 D::Error::custom("enumset contains unknown bits")
138 fn serialize<S: #serde::Serializer>(
139 set: ::enumset::EnumSet<#name>, ser: S,
140 ) -> #core::result::Result<S::Ok, S::Error> {
141 use #serde::Serialize;
142 #serialize_repr::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
144 fn deserialize<'de, D: #serde::Deserializer<'de>>(
146 ) -> #core::result::Result<::enumset::EnumSet<#name>, D::Error> {
147 use #serde::Deserialize;
148 let value = #serialize_repr::deserialize(de)?;
150 #core::prelude::v1::Ok(::enumset::EnumSet {
151 __enumset_underlying: (value & #all_variants) as #repr,
157 #[cfg(not(feature = "serde"))]
158 let serde_ops = quote! { };
161 unsafe impl ::enumset::EnumSetType for #name {
163 const ALL_BITS: Self::Repr = #all_variants;
165 fn enum_into_u8(self) -> u8 {
168 unsafe fn enum_from_u8(val: u8) -> Self {
169 #core::mem::transmute(val)
175 impl #core::cmp::PartialEq for #name {
176 fn eq(&self, other: &Self) -> bool {
177 (*self as u8) == (*other as u8)
180 impl #core::cmp::Eq for #name { }
181 impl #core::clone::Clone for #name {
182 fn clone(&self) -> Self {
186 impl #core::marker::Copy for #name { }
192 #[derive(FromDeriveInput, Default)]
193 #[darling(attributes(enumset), default)]
194 struct EnumsetAttrs {
196 serialize_as_list: bool,
197 serialize_deny_unknown: bool,
199 serialize_repr: Option<String>,
202 #[proc_macro_derive(EnumSetType, attributes(enumset))]
203 pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
204 let input: DeriveInput = parse_macro_input!(input);
205 if let Data::Enum(data) = &input.data {
206 if !input.generics.params.is_empty() {
207 error(input.generics.span(),
208 "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.")
210 let mut all_variants = 0u128;
211 let mut max_variant = 0;
212 let mut current_variant = 0;
213 let mut has_manual_discriminant = false;
215 for variant in &data.variants {
216 if let Fields::Unit = variant.fields {
217 if let Some((_, expr)) = &variant.discriminant {
218 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
219 current_variant = i.value();
220 has_manual_discriminant = true;
222 return error(variant.span(), "Unrecognized discriminant for variant.")
226 if current_variant >= 128 {
227 let message = if has_manual_discriminant {
228 "`#[derive(EnumSetType)]` only supports enum discriminants up to 127."
230 "`#[derive(EnumSetType)]` only supports enums up to 128 variants."
232 return error(variant.span(), message)
235 if all_variants & (1 << current_variant) != 0 {
236 return error(variant.span(),
237 &format!("Duplicate enum discriminant: {}", current_variant))
239 all_variants |= 1 << current_variant;
240 if current_variant > max_variant {
241 max_variant = current_variant
244 current_variant += 1;
246 return error(variant.span(),
247 "`#[derive(EnumSetType)]` can only be used on C-like enums.")
251 let repr = Ident::new(if max_variant <= 7 {
253 } else if max_variant <= 15 {
255 } else if max_variant <= 31 {
257 } else if max_variant <= 63 {
259 } else if max_variant <= 127 {
262 panic!("max_variant > 127?")
263 }, Span::call_site());
265 let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
267 Err(e) => return e.write_errors().into(),
270 match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
271 Some("u8") => if max_variant > 7 {
272 return error(input.span(), "Too many variants for u8 serialization repr.")
274 Some("u16") => if max_variant > 15 {
275 return error(input.span(), "Too many variants for u16 serialization repr.")
277 Some("u32") => if max_variant > 31 {
278 return error(input.span(), "Too many variants for u32 serialization repr.")
280 Some("u64") => if max_variant > 63 {
281 return error(input.span(), "Too many variants for u64 serialization repr.")
283 Some("u128") => if max_variant > 127 {
284 return error(input.span(), "Too many variants for u128 serialization repr.")
287 Some(x) => return error(input.span(),
288 &format!("{} is not a valid serialization repr.", x)),
291 enum_set_type_impl(&input.ident, all_variants, repr, attrs).into()
294 error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")