From f07f10772a794a77b11134c8c41131bb02d820d0 Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Sun, 12 Feb 2023 18:06:29 +0100 Subject: [PATCH] derive deserialize --- derive/src/lib.rs | 230 ++++++++++++++++++++++++++++++++++++++-------- src/lib.rs | 6 +- 2 files changed, 198 insertions(+), 38 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index dd82c23..af43416 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -263,12 +263,12 @@ fn serialize_args(res: darling::Result, body: impl FnOnce(&MtArgs) -> To ($name:ident, $T:ty) => { if args.$name { code = quote! { - mt_ser::MtSerialize::mt_serialize::<$T>(&{ - let mut __buf = Vec::new(); - let __writer = &mut __buf; - #code - __buf - }, __writer)?; + mt_ser::MtSerialize::mt_serialize::<$T>(&{ + let mut __buf = Vec::new(); + let __writer = &mut __buf; + #code + __buf + }, __writer)?; }; } }; @@ -285,6 +285,74 @@ fn serialize_args(res: darling::Result, body: impl FnOnce(&MtArgs) -> To } } +fn deserialize_args(res: darling::Result, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr { + match res { + Ok(args) => { + let mut code = body(&args); + + macro_rules! impl_const { + ($name:ident) => { + if let Some(want) = args.$name { + code = quote! { + mt_ser::MtDeserialize::mt_deserialize::(__reader) + .and_then(|got| { + if #want == got { + #code + } else { + Err(mt_ser::DeserializeError::InvalidConst( + #want as u64, got as u64 + )) + } + }) + }; + } + }; + } + + impl_const!(const64); + impl_const!(const32); + impl_const!(const16); + impl_const!(const8); + + if args.zlib { + code = quote! { + { + let mut __owned_reader = mt_ser::flate2::read::ZlibDecoder::new( + mt_ser::WrapRead(__reader)); + let __reader = &mut __owned_reader; + + #code + } + } + } + + macro_rules! impl_size { + ($name:ident, $T:ty) => { + if args.$name { + code = quote! { + $T::mt_deserialize::(__reader).and_then(|size| { + let mut __owned_reader = std::io::Read::take( + mt_ser::WrapRead(__reader), size as u64); + let __reader = &mut __owned_reader; + + #code + }) + }; + } + }; + } + + impl_size!(size8, u8); + impl_size!(size16, u16); + impl_size!(size32, u32); + impl_size!(size64, u64); + + code + } + Err(e) => return e.write_errors() + } +} + fn serialize_fields(fields: &Fields) -> TokStr { fields .iter() @@ -297,6 +365,61 @@ fn serialize_fields(fields: &Fields) -> TokStr { .collect() } +fn deserialize_fields(fields: &Fields) -> TokStr { + fields + .iter() + .map(|(ident, field)| { + let code = deserialize_args(MtArgs::from_field(field), |args| { + let cfg = get_cfg(args); + let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#cfg>(__reader) }; + + if args.default { + code = quote!{ + mt_ser::OrDefault::or_default(#code) + }; + } + + code + }); + + quote!{ + let #ident = #code?; + } + }) + .collect() +} + +fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) { + let ident_fn = match input { + syn::Fields::Unnamed(_) => |f| quote! { + mt_ser::paste::paste! { [] } + }, + _ => |f| quote! { #f }, + }; + + let fields = get_fields(input, ident_fn); + let fields_comma: TokStr = fields.iter() + .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after }); + + let fields_struct = match input { + syn::Fields::Named(_) => quote! { { #fields_comma } }, + syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) }, + syn::Fields::Unit => TokStr::new(), + }; + + (fields, fields_struct) +} + +fn get_repr(input: &syn::DeriveInput) -> syn::Type { + input + .attrs + .iter() + .find(|a| a.path.is_ident("repr")) + .expect("missing repr") + .parse_args() + .expect("invalid repr") +} + #[proc_macro_derive(MtSerialize, attributes(mt))] pub fn derive_serialize(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as syn::DeriveInput); @@ -304,35 +427,12 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream { let code = serialize_args(MtArgs::from_derive_input(&input), |_| match &input.data { syn::Data::Enum(e) => { - let repr: syn::Type = input - .attrs - .iter() - .find(|a| a.path.is_ident("repr")) - .expect("missing repr") - .parse_args() - .expect("invalid repr"); - + let repr = get_repr(&input); let variants: TokStr = e.variants .iter() .fold((parse_quote! { 0 }, TokStr::new()), |(discr, before), v| { let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr); - - let ident_fn = match &v.fields { - syn::Fields::Unnamed(_) => |f| quote! { - mt_ser::paste::paste! { [] } - }, - _ => |f| quote! { #f }, - }; - - let fields = get_fields(&v.fields, ident_fn); - let fields_comma: TokStr = fields.iter() - .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after }); - - let destruct = match &v.fields { - syn::Fields::Named(_) => quote! { { #fields_comma } }, - syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) }, - syn::Fields::Unit => TokStr::new(), - }; + let (fields, fields_struct) = get_fields_struct(&v.fields); let code = serialize_args(MtArgs::from_variant(v), |_| serialize_fields(&fields)); @@ -342,7 +442,7 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream { parse_quote! { 1 + #discr }, quote! { #before - #typename::#variant #destruct => { + #typename::#variant #fields_struct => { mt_ser::MtSerialize::mt_serialize::(&((#discr) as #repr), __writer)?; #code } @@ -376,14 +476,72 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream { #[proc_macro_derive(MtDeserialize, attributes(mt))] pub fn derive_deserialize(input: TokenStream) -> TokenStream { - let syn::DeriveInput { - ident: typename, .. - } = parse_macro_input!(input); + let input = parse_macro_input!(input as syn::DeriveInput); + let typename = &input.ident; + + let code = deserialize_args(MtArgs::from_derive_input(&input), |_| match &input.data { + syn::Data::Enum(e) => { + let repr = get_repr(&input); + let type_str = typename.to_string(); + + let mut consts = TokStr::new(); + let mut arms = TokStr::new(); + let mut discr = parse_quote! { 0 }; + + for v in e.variants.iter() { + discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr); + + let ident = &v.ident; + let (fields, fields_struct) = get_fields_struct(&v.fields); + let code = deserialize_args(MtArgs::from_variant(v), |_| { + let fields_code = deserialize_fields(&fields); + + quote! { + #fields_code + Ok(Self::#ident #fields_struct) + } + }); + + consts.extend(quote! { + const #ident: #repr = #discr; + }); + + arms.extend(quote! { + #ident => { #code } + }); + + discr = parse_quote! { 1 + #discr }; + } + + quote! { + #consts + + match mt_ser::MtDeserialize::mt_deserialize::(__reader)? { + #arms + x => Err(mt_ser::DeserializeError::InvalidEnumVariant(#type_str, x as u64)) + } + } + }, + syn::Data::Struct(s) => { + let (fields, fields_struct) = get_fields_struct(&s.fields); + let code = deserialize_fields(&fields); + + quote!{ + #code + Ok(Self #fields_struct) + } + }, + _ => { + panic!("only enum and struct supported"); + } + }); + quote! { #[automatically_derived] impl mt_ser::MtDeserialize for #typename { + #[allow(non_upper_case_globals)] fn mt_deserialize(__reader: &mut impl std::io::Read) -> Result { - Err(mt_ser::DeserializeError::Unimplemented) + #code } } }.into() diff --git a/src/lib.rs b/src/lib.rs index 0019d1c..f57e487 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,8 +45,10 @@ pub enum DeserializeError { TooBig(#[from] TryFromIntError), #[error("invalid UTF-16: {0}")] InvalidUtf16(#[from] std::char::DecodeUtf16Error), - #[error("unimplemented")] - Unimplemented, + #[error("invalid {0} enum variant {1}")] + InvalidEnumVariant(&'static str, u64), + #[error("invalid constant - wanted: {0} - got: {1}")] + InvalidConst(u64, u64), } impl From for DeserializeError { -- 2.44.0