]> git.lizzy.rs Git - mt_ser.git/commitdiff
derive deserialize
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Sun, 12 Feb 2023 17:06:29 +0000 (18:06 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Sun, 12 Feb 2023 17:06:29 +0000 (18:06 +0100)
derive/src/lib.rs
src/lib.rs

index dd82c23cf3e60e47a8ad066b4f71a212ff7831c3..af43416ec8a47396f01148b72d02bb57310c65d1 100644 (file)
@@ -263,12 +263,12 @@ fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> To
                                ($name:ident, $T:ty) => {
                                        if args.$name {
                                                code = quote! {
-                                                               mt_ser::MtSerialize::mt_serialize::<$T>(&{
-                                                                       let mut __buf = Vec::new();
-                                                                       let __writer = &mut __buf;
-                                                                       #code
-                                                                       __buf
-                                                               }, __writer)?;
+                                                       mt_ser::MtSerialize::mt_serialize::<$T>(&{
+                                                               let mut __buf = Vec::new();
+                                                               let __writer = &mut __buf;
+                                                               #code
+                                                               __buf
+                                                       }, __writer)?;
                                                };
                                        }
                                };
@@ -285,6 +285,74 @@ fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> To
        }
 }
 
+fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
+       match res {
+               Ok(args) => {
+                       let mut code = body(&args);
+
+                       macro_rules! impl_const {
+                               ($name:ident) => {
+                                       if let Some(want) = args.$name {
+                                               code = quote! {
+                                                       mt_ser::MtDeserialize::mt_deserialize::<mt_ser::DefCfg>(__reader)
+                                                               .and_then(|got| {
+                                                                       if #want == got {
+                                                                               #code
+                                                                       } else {
+                                                                               Err(mt_ser::DeserializeError::InvalidConst(
+                                                                                       #want as u64, got as u64
+                                                                               ))
+                                                                       }
+                                                               })
+                                               };
+                                       }
+                               };
+                       }
+
+                       impl_const!(const64);
+                       impl_const!(const32);
+                       impl_const!(const16);
+                       impl_const!(const8);
+
+                       if args.zlib {
+                               code = quote! {
+                                       {
+                                               let mut __owned_reader = mt_ser::flate2::read::ZlibDecoder::new(
+                                                       mt_ser::WrapRead(__reader));
+                                               let __reader = &mut __owned_reader;
+
+                                               #code
+                                       }
+                               }
+                       }
+
+                       macro_rules! impl_size {
+                               ($name:ident, $T:ty) => {
+                                       if args.$name {
+                                               code = quote! {
+                                                       $T::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
+                                                               let mut __owned_reader = std::io::Read::take(
+                                                                       mt_ser::WrapRead(__reader), size as u64);
+                                                               let __reader = &mut __owned_reader;
+
+                                                               #code
+                                                       })
+                                               };
+                                       }
+                               };
+                       }
+
+                       impl_size!(size8, u8);
+                       impl_size!(size16, u16);
+                       impl_size!(size32, u32);
+                       impl_size!(size64, u64);
+
+                       code
+               }
+               Err(e) => return e.write_errors()
+       }
+}
+
 fn serialize_fields(fields: &Fields) -> TokStr {
        fields
                .iter()
@@ -297,6 +365,61 @@ fn serialize_fields(fields: &Fields) -> TokStr {
                .collect()
 }
 
+fn deserialize_fields(fields: &Fields) -> TokStr {
+       fields
+               .iter()
+               .map(|(ident, field)| {
+                       let code = deserialize_args(MtArgs::from_field(field), |args| {
+                               let cfg = get_cfg(args);
+                               let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#cfg>(__reader) };
+
+                               if args.default {
+                                       code = quote!{
+                                               mt_ser::OrDefault::or_default(#code)
+                                       };
+                               }
+
+                               code
+                       });
+
+                       quote!{
+                               let #ident = #code?;
+                       }
+               })
+               .collect()
+}
+
+fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
+       let ident_fn = match input {
+               syn::Fields::Unnamed(_) => |f| quote! {
+                       mt_ser::paste::paste! { [<field_ #f>] }
+               },
+               _ => |f| quote! { #f },
+       };
+
+       let fields = get_fields(input, ident_fn);
+       let fields_comma: TokStr = fields.iter()
+               .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
+
+       let fields_struct = match input {
+               syn::Fields::Named(_) => quote! { { #fields_comma } },
+               syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
+               syn::Fields::Unit => TokStr::new(),
+       };
+
+       (fields, fields_struct)
+}
+
+fn get_repr(input: &syn::DeriveInput) -> syn::Type {
+       input
+               .attrs
+               .iter()
+               .find(|a| a.path.is_ident("repr"))
+               .expect("missing repr")
+               .parse_args()
+               .expect("invalid repr")
+}
+
 #[proc_macro_derive(MtSerialize, attributes(mt))]
 pub fn derive_serialize(input: TokenStream) -> TokenStream {
        let input = parse_macro_input!(input as syn::DeriveInput);
@@ -304,35 +427,12 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
 
        let code = serialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
                syn::Data::Enum(e) => {
-                       let repr: syn::Type = input
-                               .attrs
-                               .iter()
-                               .find(|a| a.path.is_ident("repr"))
-                               .expect("missing repr")
-                               .parse_args()
-                               .expect("invalid repr");
-
+                       let repr = get_repr(&input);
                        let variants: TokStr = e.variants
                                .iter()
                                .fold((parse_quote! { 0 }, TokStr::new()), |(discr, before), v| {
                                        let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
-
-                                       let ident_fn = match &v.fields {
-                                               syn::Fields::Unnamed(_) => |f| quote! {
-                                                       mt_ser::paste::paste! { [<field_ #f>] }
-                                               },
-                                               _ => |f| quote! { #f },
-                                       };
-
-                                       let fields = get_fields(&v.fields, ident_fn);
-                                       let fields_comma: TokStr = fields.iter()
-                                               .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
-
-                                       let destruct = match &v.fields {
-                                               syn::Fields::Named(_) => quote! { { #fields_comma } },
-                                               syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
-                                               syn::Fields::Unit => TokStr::new(),
-                                       };
+                                       let (fields, fields_struct) = get_fields_struct(&v.fields);
 
                                        let code = serialize_args(MtArgs::from_variant(v), |_|
                                                serialize_fields(&fields));
@@ -342,7 +442,7 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
                                                parse_quote! { 1 + #discr },
                                                quote! {
                                                        #before
-                                                       #typename::#variant #destruct => {
+                                                       #typename::#variant #fields_struct => {
                                                                mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
                                                                #code
                                                        }
@@ -376,14 +476,72 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
 
 #[proc_macro_derive(MtDeserialize, attributes(mt))]
 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
-       let syn::DeriveInput {
-               ident: typename, ..
-       } = parse_macro_input!(input);
+       let input = parse_macro_input!(input as syn::DeriveInput);
+       let typename = &input.ident;
+
+       let code = deserialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
+               syn::Data::Enum(e) => {
+                       let repr = get_repr(&input);
+                       let type_str = typename.to_string();
+
+                       let mut consts = TokStr::new();
+                       let mut arms = TokStr::new();
+                       let mut discr = parse_quote! { 0 };
+
+                       for v in e.variants.iter() {
+                               discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
+
+                               let ident = &v.ident;
+                               let (fields, fields_struct) = get_fields_struct(&v.fields);
+                               let code = deserialize_args(MtArgs::from_variant(v), |_| {
+                                       let fields_code = deserialize_fields(&fields);
+
+                                       quote! {
+                                               #fields_code
+                                               Ok(Self::#ident #fields_struct)
+                                       }
+                               });
+
+                               consts.extend(quote! {
+                                       const #ident: #repr = #discr;
+                               });
+
+                               arms.extend(quote! {
+                                       #ident => { #code }
+                               });
+
+                               discr = parse_quote! { 1 + #discr };
+                       }
+
+                       quote! {
+                               #consts
+
+                               match mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)? {
+                                       #arms
+                                       x => Err(mt_ser::DeserializeError::InvalidEnumVariant(#type_str, x as u64))
+                               }
+                       }
+               },
+               syn::Data::Struct(s) => {
+                       let (fields, fields_struct) = get_fields_struct(&s.fields);
+                       let code = deserialize_fields(&fields);
+
+                       quote!{
+                               #code
+                               Ok(Self #fields_struct)
+                       }
+               },
+               _ => {
+                       panic!("only enum and struct supported");
+               }
+       });
+
        quote! {
                #[automatically_derived]
                impl mt_ser::MtDeserialize for #typename {
+                       #[allow(non_upper_case_globals)]
                        fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
-                               Err(mt_ser::DeserializeError::Unimplemented)
+                               #code
                        }
                }
        }.into()
index 0019d1c0b63e023abc5befa8b21e824a8e972d6b..f57e487eaa8c635e4eddee7196fe4fe85613fd20 100644 (file)
@@ -45,8 +45,10 @@ pub enum DeserializeError {
     TooBig(#[from] TryFromIntError),
     #[error("invalid UTF-16: {0}")]
     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
-    #[error("unimplemented")]
-    Unimplemented,
+    #[error("invalid {0} enum variant {1}")]
+    InvalidEnumVariant(&'static str, u64),
+    #[error("invalid constant - wanted: {0} - got: {1}")]
+    InvalidConst(u64, u64),
 }
 
 impl From<Infallible> for DeserializeError {