]> git.lizzy.rs Git - enumset.git/blobdiff - enumset_derive/src/lib.rs
Add serialize_as_map
[enumset.git] / enumset_derive / src / lib.rs
index 5e1afa05b1bd2ab5c471d22924b3f3eaee3fa219..04fc0df77539af6418c256b994290b57c23fc2e4 100644 (file)
@@ -24,6 +24,7 @@ struct EnumsetAttrs {
     #[darling(default)]
     repr: Option<String>,
     serialize_as_list: bool,
+    serialize_as_map: bool,
     serialize_deny_unknown: bool,
     #[darling(default)]
     serialize_repr: Option<String>,
@@ -72,6 +73,8 @@ struct EnumSetInfo {
     no_super_impls: bool,
     /// Serialize the enum as a list.
     serialize_as_list: bool,
+    /// Serialize the enum as a map.
+    serialize_as_map: bool,
     /// Disallow unknown bits while deserializing the enum.
     serialize_deny_unknown: bool,
 }
@@ -94,6 +97,7 @@ impl EnumSetInfo {
             no_ops: attrs.no_ops,
             no_super_impls: attrs.no_super_impls,
             serialize_as_list: attrs.serialize_as_list,
+            serialize_as_map: attrs.serialize_as_map,
             serialize_deny_unknown: attrs.serialize_deny_unknown,
         }
     }
@@ -185,39 +189,35 @@ impl EnumSetInfo {
     }
     /// Validate the enumset type.
     fn validate(&self) -> Result<()> {
-        // Check if all bits of the bitset can fit in the serialization representation.
-        if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
-            let is_overflowed = match explicit_serde_repr.to_string().as_str() {
-                "u8" => self.max_discrim >= 8,
-                "u16" => self.max_discrim >= 16,
-                "u32" => self.max_discrim >= 32,
-                "u64" => self.max_discrim >= 64,
-                "u128" => self.max_discrim >= 128,
+        fn do_check(ty: &str, max_discrim: u32, what: &str) -> Result<()> {
+            let is_overflowed = match ty {
+                "u8" => max_discrim >= 8,
+                "u16" => max_discrim >= 16,
+                "u32" => max_discrim >= 32,
+                "u64" => max_discrim >= 64,
+                "u128" => max_discrim >= 128,
                 _ => error(
                     Span::call_site(),
-                    "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr.",
+                    format!(
+                        "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for {}.",
+                        what
+                    ),
                 )?,
             };
             if is_overflowed {
-                error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
+                error(Span::call_site(), format!("{} cannot be smaller than bitset.", what))?;
             }
+            Ok(())
         }
+
+        // Check if all bits of the bitset can fit in the serialization representation.
+        if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
+            do_check(&explicit_serde_repr.to_string(), self.max_discrim, "serialize_repr")?;
+        }
+
         // Check if all bits of the bitset can fit in the memory representation, if one was given.
         if let Some(explicit_mem_repr) = &self.explicit_mem_repr {
-            let is_overflowed = match explicit_mem_repr.to_string().as_str() {
-                "u8" => self.max_discrim >= 8,
-                "u16" => self.max_discrim >= 16,
-                "u32" => self.max_discrim >= 32,
-                "u64" => self.max_discrim >= 64,
-                "u128" => self.max_discrim >= 128,
-                _ => error(
-                    Span::call_site(),
-                    "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for repr.",
-                )?,
-            };
-            if is_overflowed {
-                error(Span::call_site(), "repr cannot be smaller than bitset.")?;
-            }
+            do_check(&explicit_mem_repr.to_string(), self.max_discrim, "repr")?;
         }
         Ok(())
     }
@@ -379,6 +379,45 @@ fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
                 de.deserialize_seq(Visitor)
             }
         }
+    } else if info.serialize_as_map {
+        let expecting_str = format!("a map from {} to bool", name);
+        quote! {
+            fn serialize<S: #serde::Serializer>(
+                set: #enumset::EnumSet<#name>, ser: S,
+            ) -> #core::result::Result<S::Ok, S::Error> {
+                use #serde::ser::SerializeMap;
+                let mut map = ser.serialize_map(#core::prelude::v1::Some(set.len()))?;
+                for bit in set {
+                    map.serialize_entry(&bit, &true)?;
+                }
+                map.end()
+            }
+            fn deserialize<'de, D: #serde::Deserializer<'de>>(
+                de: D,
+            ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
+                struct Visitor;
+                impl <'de> #serde::de::Visitor<'de> for Visitor {
+                    type Value = #enumset::EnumSet<#name>;
+                    fn expecting(
+                        &self, formatter: &mut #core::fmt::Formatter,
+                    ) -> #core::fmt::Result {
+                        write!(formatter, #expecting_str)
+                    }
+                    fn visit_map<A>(
+                        mut self, mut map: A,
+                    ) -> #core::result::Result<Self::Value, A::Error> where
+                        A: #serde::de::MapAccess<'de>
+                    {
+                        let mut accum = #enumset::EnumSet::<#name>::new();
+                        while let #core::prelude::v1::Some((val, true)) = map.next_entry::<#name, bool>()? {
+                            accum |= val;
+                        }
+                        #core::prelude::v1::Ok(accum)
+                    }
+                }
+                de.deserialize_map(Visitor)
+            }
+        }
     } else {
         let serialize_repr = info.serde_repr();
         let check_unknown = if info.serialize_deny_unknown {