]> git.lizzy.rs Git - enumset.git/commitdiff
Use u32 instead of u8 for variant count related methods.
authorLymia Aluysia <lymia@lymiahugs.com>
Wed, 1 Apr 2020 17:47:06 +0000 (10:47 -0700)
committerLymia Aluysia <lymia@lymiahugs.com>
Wed, 1 Apr 2020 17:47:06 +0000 (10:47 -0700)
enumset/src/lib.rs
enumset/tests/ops.rs
enumset_derive/src/lib.rs

index 92f0876d8ba04ce94e0b390ee0e327bb26ad0a40..252f3768452be614b70f38044f5efdd5fa99893f 100644 (file)
@@ -105,8 +105,8 @@ pub mod internal {
     pub unsafe trait EnumSetTypePrivate {
         type Repr: EnumSetTypeRepr;
         const ALL_BITS: Self::Repr;
-        fn enum_into_u8(self) -> u8;
-        unsafe fn enum_from_u8(val: u8) -> Self;
+        fn enum_into_u32(self) -> u32;
+        unsafe fn enum_from_u32(val: u32) -> Self;
 
         #[cfg(feature = "serde")]
         fn serialize<S: serde::Serializer>(set: EnumSet<Self>, ser: S) -> Result<S::Ok, S::Error>
@@ -128,7 +128,7 @@ mod private {
         AsPrimitive<u8> + AsPrimitive<u16> + AsPrimitive<u32> + AsPrimitive<u64> +
         AsPrimitive<u128> + AsPrimitive<usize>
     {
-        const WIDTH: u8;
+        const WIDTH: u32;
 
         fn from_u8(v: u8) -> Self;
         fn from_u16(v: u16) -> Self;
@@ -140,7 +140,7 @@ mod private {
     macro_rules! prim {
         ($name:ty, $width:expr) => {
             impl EnumSetTypeRepr for $name {
-                const WIDTH: u8 = $width;
+                const WIDTH: u32 = $width;
                 fn from_u8(v: u8) -> Self { v.as_() }
                 fn from_u16(v: u16) -> Self { v.as_() }
                 fn from_u32(v: u32) -> Self { v.as_() }
@@ -238,15 +238,15 @@ pub struct EnumSet<T: EnumSetType> {
     pub __enumset_underlying: T::Repr
 }
 impl <T: EnumSetType> EnumSet<T> {
-    fn mask(bit: u8) -> T::Repr {
+    fn mask(bit: u32) -> T::Repr {
         Shl::<usize>::shl(T::Repr::one(), bit as usize)
     }
-    fn has_bit(&self, bit: u8) -> bool {
+    fn has_bit(&self, bit: u32) -> bool {
         let mask = Self::mask(bit);
         self.__enumset_underlying & mask == mask
     }
-    fn partial_bits(bits: u8) -> T::Repr {
-        T::Repr::one().checked_shl(bits.into())
+    fn partial_bits(bits: u32) -> T::Repr {
+        T::Repr::one().checked_shl(bits as u32)
             .unwrap_or(T::Repr::zero())
             .wrapping_sub(&T::Repr::one())
     }
@@ -263,7 +263,7 @@ impl <T: EnumSetType> EnumSet<T> {
 
     /// Returns an `EnumSet` containing a single element.
     pub fn only(t: T) -> Self {
-        EnumSet { __enumset_underlying: Self::mask(t.enum_into_u8()) }
+        EnumSet { __enumset_underlying: Self::mask(t.enum_into_u32()) }
     }
 
     /// Creates an empty `EnumSet`.
@@ -283,16 +283,16 @@ impl <T: EnumSetType> EnumSet<T> {
     ///
     /// This is the same as [`EnumSet::variant_count`] except in enums with "sparse" variants.
     /// (e.g. `enum Foo { A = 10, B = 20 }`)
-    pub fn bit_width() -> u8 {
-        T::Repr::WIDTH - T::ALL_BITS.leading_zeros() as u8
+    pub fn bit_width() -> u32 {
+        T::Repr::WIDTH - T::ALL_BITS.leading_zeros()
     }
 
     /// The number of valid variants that this type can contain.
     ///
     /// This is the same as [`EnumSet::bit_width`] except in enums with "sparse" variants.
     /// (e.g. `enum Foo { A = 10, B = 20 }`)
-    pub fn variant_count() -> u8 {
-        T::ALL_BITS.count_ones() as u8
+    pub fn variant_count() -> u32 {
+        T::ALL_BITS.count_ones()
     }
 
     /// Returns the number of elements in this set.
@@ -348,7 +348,7 @@ impl <T: EnumSetType> EnumSet<T> {
 
     /// Checks whether this set contains a value.
     pub fn contains(&self, value: T) -> bool {
-        self.has_bit(value.enum_into_u8())
+        self.has_bit(value.enum_into_u32())
     }
 
     /// Adds a value to this set.
@@ -358,13 +358,13 @@ impl <T: EnumSetType> EnumSet<T> {
     /// If the set did have this value present, `false` is returned.
     pub fn insert(&mut self, value: T) -> bool {
         let contains = !self.contains(value);
-        self.__enumset_underlying = self.__enumset_underlying | Self::mask(value.enum_into_u8());
+        self.__enumset_underlying = self.__enumset_underlying | Self::mask(value.enum_into_u32());
         contains
     }
     /// Removes a value from this set. Returns whether the value was present in the set.
     pub fn remove(&mut self, value: T) -> bool {
         let contains = self.contains(value);
-        self.__enumset_underlying = self.__enumset_underlying & !Self::mask(value.enum_into_u8());
+        self.__enumset_underlying = self.__enumset_underlying & !Self::mask(value.enum_into_u32());
         contains
     }
 
@@ -553,7 +553,7 @@ impl <T: EnumSetType> From<T> for EnumSet<T> {
 
 impl <T: EnumSetType> PartialEq<T> for EnumSet<T> {
     fn eq(&self, other: &T) -> bool {
-        self.__enumset_underlying == EnumSet::<T>::mask(other.enum_into_u8())
+        self.__enumset_underlying == EnumSet::<T>::mask(other.enum_into_u32())
     }
 }
 impl <T: EnumSetType + Debug> Debug for EnumSet<T> {
@@ -602,7 +602,7 @@ impl <'de, T: EnumSetType> Deserialize<'de> for EnumSet<T> {
 
 /// The iterator used by [`EnumSet`]s.
 #[derive(Clone, Debug)]
-pub struct EnumSetIter<T: EnumSetType>(EnumSet<T>, u8);
+pub struct EnumSetIter<T: EnumSetType>(EnumSet<T>, u32);
 impl <T: EnumSetType> Iterator for EnumSetIter<T> {
     type Item = T;
 
@@ -611,7 +611,7 @@ impl <T: EnumSetType> Iterator for EnumSetIter<T> {
             let bit = self.1;
             self.1 += 1;
             if self.0.has_bit(bit) {
-                return unsafe { Some(T::enum_from_u8(bit)) }
+                return unsafe { Some(T::enum_from_u32(bit)) }
             }
         }
         None
@@ -668,7 +668,7 @@ macro_rules! enum_set {
         $crate::internal::EnumSetSameTypeHack {
             unified: &[$($value,)*],
             enum_set: $crate::EnumSet {
-                __enumset_underlying: 0 $(| (1 << ($value as u8)))*
+                __enumset_underlying: 0 $(| (1 << ($value as u32)))*
             },
         }.enum_set
     };
index e82ad640bd8e749ee969251c30ccc22500fbd9da..b5fd0e55378a4e6eea9684382302eb10ba72afd0 100644 (file)
@@ -48,6 +48,12 @@ pub enum SparseEnum {
     A = 0xA, B = 20, C = 30, D = 40, E = 50, F = 60, G = 70, H = 80,
 }
 
+#[repr(u32)]
+#[derive(EnumSetType, Debug)]
+pub enum ReprEnum {
+    A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z,
+}
+
 macro_rules! test_variants {
     ($enum_name:ident $all_empty_test:ident $($variant:ident,)*) => {
         #[test]
@@ -300,6 +306,7 @@ tests!(large_enum, test_enum!(LargeEnum, 16));
 tests!(enum8, test_enum!(Enum8, 1));
 tests!(enum128, test_enum!(Enum128, 16));
 tests!(sparse_enum, test_enum!(SparseEnum, 16));
+tests!(repr_enum, test_enum!(ReprEnum, 4));
 
 #[derive(EnumSetType, Debug)]
 pub enum ThresholdEnum {
index b4dcad193d5816724a301443d63a5c38ef003ee4..99b5f99c411153fb5e43e3b1784fd0108c048ab8 100644 (file)
@@ -1,29 +1,25 @@
 #![recursion_limit="256"]
 #![cfg_attr(feature = "nightly", feature(proc_macro_diagnostic))]
 
+// TODO: Read #[repr(...)] attributes.
+
 extern crate proc_macro;
 
 use darling::*;
 use proc_macro::TokenStream;
 use proc_macro2::{TokenStream as SynTokenStream, Literal};
-use syn::*;
+use syn::{*, Result, Error};
 use syn::export::Span;
 use syn::spanned::Spanned;
 use quote::*;
 
-#[cfg(feature = "nightly")]
-fn error(span: Span, data: &str) -> TokenStream {
-    span.unstable().error(data).emit();
-    TokenStream::new()
-}
-
-#[cfg(not(feature = "nightly"))]
-fn error(_: Span, data: &str) -> TokenStream {
-    panic!("{}", data)
+fn error<T>(span: Span, message: &str) -> Result<T> {
+    Err(Error::new(span, message))
 }
 
 fn enum_set_type_impl(
     name: &Ident, all_variants: u128, repr: Ident, attrs: EnumsetAttrs, variants: Vec<Ident>,
+    enum_repr: Ident,
 ) -> SynTokenStream {
     let is_uninhabited = variants.is_empty();
     let is_zst = variants.len() == 1;
@@ -157,30 +153,30 @@ fn enum_set_type_impl(
 
     let into_impl = if is_uninhabited {
         quote! {
-            fn enum_into_u8(self) -> u8 {
+            fn enum_into_u32(self) -> u32 {
                 panic!(concat!(stringify!(#name), " is uninhabited."))
             }
-            unsafe fn enum_from_u8(val: u8) -> Self {
+            unsafe fn enum_from_u32(val: u32) -> Self {
                 panic!(concat!(stringify!(#name), " is uninhabited."))
             }
         }
     } else if is_zst {
         let variant = &variants[0];
         quote! {
-            fn enum_into_u8(self) -> u8 {
-                self as u8
+            fn enum_into_u32(self) -> u32 {
+                self as u32
             }
-            unsafe fn enum_from_u8(val: u8) -> Self {
+            unsafe fn enum_from_u32(val: u32) -> Self {
                 #name::#variant
             }
         }
     } else {
         quote! {
-            fn enum_into_u8(self) -> u8 {
-                self as u8
+            fn enum_into_u32(self) -> u32 {
+                self as u32
             }
-            unsafe fn enum_from_u8(val: u8) -> Self {
-                #core::mem::transmute(val)
+            unsafe fn enum_from_u32(val: u32) -> Self {
+                #core::mem::transmute(val as #enum_repr)
             }
         }
     };
@@ -188,7 +184,7 @@ fn enum_set_type_impl(
     let eq_impl = if is_uninhabited {
         quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
     } else {
-        quote!((*self as u8) == (*other as u8))
+        quote!((*self as u32) == (*other as u32))
     };
 
     quote! {
@@ -228,17 +224,17 @@ struct EnumsetAttrs {
     serialize_repr: Option<String>,
 }
 
-#[proc_macro_derive(EnumSetType, attributes(enumset))]
-pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
-    let input: DeriveInput = parse_macro_input!(input);
+fn derive_enum_set_type_impl(input: DeriveInput) -> Result<TokenStream> {
     if let Data::Enum(data) = &input.data {
         if !input.generics.params.is_empty() {
-            error(input.generics.span(),
-                  "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.")
+            error(
+                input.generics.span(),
+                "`#[derive(EnumSetType)]` cannot be used on enums with type parameters."
+            )
         } else {
             let mut all_variants = 0u128;
-            let mut max_variant = 0;
-            let mut current_variant = 0;
+            let mut max_variant = 0u32;
+            let mut current_variant = 0u32;
             let mut has_manual_discriminant = false;
             let mut variants = Vec::new();
 
@@ -248,28 +244,36 @@ pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
                         if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
                             current_variant = match i.base10_parse() {
                                 Ok(val) => val,
-                                Err(_) => return error(expr.span(), "Error parsing discriminant."),
+                                Err(_) => error(
+                                    expr.span(), "Could not parse discriminant as u32.",
+                                )?,
                             };
                             has_manual_discriminant = true;
                         } else {
-                            return error(variant.span(), "Unrecognized discriminant for variant.")
+                            error(
+                                variant.span(), "Unrecognized discriminant for variant."
+                            )?;
                         }
                     }
 
                     if current_variant >= 128 {
                         let message = if has_manual_discriminant {
-                            "`#[derive(EnumSetType)]` only supports enum discriminants up to 127."
+                            "`#[derive(EnumSetType)]` currently only supports \
+                             enum discriminants up to 127."
                         } else {
-                            "`#[derive(EnumSetType)]` only supports enums up to 128 variants."
+                            "`#[derive(EnumSetType)]` currently only supports \
+                             enums up to 128 variants."
                         };
-                        return error(variant.span(), message)
+                        error(variant.span(), message)?;
                     }
 
-                    if all_variants & (1 << current_variant) != 0 {
-                        return error(variant.span(),
-                                     &format!("Duplicate enum discriminant: {}", current_variant))
+                    if all_variants & (1 << current_variant as u128) != 0 {
+                        error(
+                            variant.span(),
+                            &format!("Duplicate enum discriminant: {}", current_variant)
+                        )?;
                     }
-                    all_variants |= 1 << current_variant;
+                    all_variants |= 1 << current_variant as u128;
                     if current_variant > max_variant {
                         max_variant = current_variant
                     }
@@ -277,8 +281,10 @@ pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
                     variants.push(variant.ident.clone());
                     current_variant += 1;
                 } else {
-                    return error(variant.span(),
-                                 "`#[derive(EnumSetType)]` can only be used on C-like enums.")
+                    error(
+                        variant.span(),
+                        "`#[derive(EnumSetType)]` can only be used on C-like enums."
+                    )?;
                 }
             }
 
@@ -298,33 +304,72 @@ pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
 
             let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
                 Ok(attrs) => attrs,
-                Err(e) => return e.write_errors().into(),
+                Err(e) => return Ok(e.write_errors().into()),
             };
 
+            let mut enum_repr = None;
+            for attr in &input.attrs {
+                if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
+                    let meta: Ident = attr.parse_args()?;
+                    if enum_repr.is_some() {
+                        error(attr.span(), "Cannot duplicate #[repr(...)] annotations.")?;
+                    }
+                    let repr_max_variant = match meta.to_string().as_str() {
+                        "u8" => 0xFF,
+                        "u16" => 0xFFFF,
+                        "u32" => 0xFFFFFFFF,
+                        _ => error(attr.span(), "Only `u8`, `u16` and `u32` reprs are supported.")?,
+                    };
+                    if max_variant > repr_max_variant {
+                        error(attr.span(), "A variant of this enum overflows its repr.")?;
+                    }
+                    enum_repr = Some(meta);
+                }
+            }
+            let enum_repr = enum_repr.unwrap_or_else(|| if max_variant < 0x100 {
+                Ident::new("u8", Span::call_site())
+            } else if max_variant < 0x10000 {
+                Ident::new("u16", Span::call_site())
+            } else {
+                Ident::new("u32", Span::call_site())
+            });
+
             match attrs.serialize_repr.as_ref().map(|x| x.as_str()) {
                 Some("u8") => if max_variant > 7 {
-                    return error(input.span(), "Too many variants for u8 serialization repr.")
+                    error(input.span(), "Too many variants for u8 serialization repr.")?;
                 }
                 Some("u16") => if max_variant > 15 {
-                    return error(input.span(), "Too many variants for u16 serialization repr.")
+                    error(input.span(), "Too many variants for u16 serialization repr.")?;
                 }
                 Some("u32") => if max_variant > 31 {
-                    return error(input.span(), "Too many variants for u32 serialization repr.")
+                    error(input.span(), "Too many variants for u32 serialization repr.")?;
                 }
                 Some("u64") => if max_variant > 63 {
-                    return error(input.span(), "Too many variants for u64 serialization repr.")
+                    error(input.span(), "Too many variants for u64 serialization repr.")?;
                 }
                 Some("u128") => if max_variant > 127 {
-                    return error(input.span(), "Too many variants for u128 serialization repr.")
+                    error(input.span(), "Too many variants for u128 serialization repr.")?;
                 }
                 None => { }
-                Some(x) => return error(input.span(),
-                                        &format!("{} is not a valid serialization repr.", x)),
+                Some(x) => error(
+                    input.span(), &format!("{} is not a valid serialization repr.", x)
+                )?,
             };
 
-            enum_set_type_impl(&input.ident, all_variants, repr, attrs, variants).into()
+            Ok(enum_set_type_impl(
+                &input.ident, all_variants, repr, attrs, variants, enum_repr,
+            ).into())
         }
     } else {
         error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
     }
 }
+
+#[proc_macro_derive(EnumSetType, attributes(enumset))]
+pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
+    let input: DeriveInput = parse_macro_input!(input);
+    match derive_enum_set_type_impl(input) {
+        Ok(v) => v,
+        Err(e) => e.to_compile_error().into(),
+    }
+}
\ No newline at end of file