]> git.lizzy.rs Git - mt_ser.git/commitdiff
Improve attributes
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Mon, 13 Feb 2023 15:10:48 +0000 (16:10 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Mon, 13 Feb 2023 15:10:48 +0000 (16:10 +0100)
derive/src/lib.rs
src/lib.rs

index 267ec00f60be5d8bc2900b2eba4540021d553166..52de2f2bc05829d5c4a4452f63ed123169cfef21 100644 (file)
@@ -158,50 +158,19 @@ pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
 #[darling(attributes(mt))]
 #[darling(default)]
 struct MtArgs {
-       const8: Option<u8>,
-       const16: Option<u16>,
-       const32: Option<u32>,
-       const64: Option<u64>,
-       size8: bool,
-       size16: bool,
-       size32: bool,
-       size64: bool,
-       len0: bool,
-       len8: bool,
-       len16: bool,
-       len32: bool,
-       len64: bool,
-       utf16: bool,
-       zlib: bool,
-       zstd: bool, // TODO
-       default: bool,
-}
+       #[darling(multiple)]
+       const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
 
-fn get_cfg(args: &MtArgs) -> syn::Type {
-       let mut ty: syn::Type = parse_quote! { mt_ser::DefCfg  };
+       #[darling(multiple)]
+       const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
 
-       if args.len0 {
-               ty = parse_quote! { () };
-       }
+       size: Option<syn::Type>, // must implement MtCfg
 
-       macro_rules! impl_len {
-               ($name:ident, $T:ty) => {
-                       if args.$name {
-                               ty = parse_quote! { $T  };
-                       }
-               };
-       }
-
-       impl_len!(len8, u8);
-       impl_len!(len16, u16);
-       impl_len!(len32, u32);
-       impl_len!(len64, u64);
+       len: Option<syn::Type>, // must implement MtCfg
 
-       if args.utf16 {
-               ty = parse_quote! { mt_ser::Utf16<#ty> };
-       }
-
-       ty
+       zlib: bool,
+       zstd: bool, // TODO
+       default: bool, // type must implement Default
 }
 
 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
@@ -242,41 +211,30 @@ fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> To
                                };
                        }
 
-                       macro_rules! impl_size {
-                               ($name:ident, $T:ty) => {
-                                       if args.$name {
-                                               code = quote! {
-                                                       mt_ser::MtSerialize::mt_serialize::<$T>(&{
-                                                               let mut __buf = Vec::new();
-                                                               let __writer = &mut __buf;
-                                                               #code
-                                                               __buf
-                                                       }, __writer)?;
-                                               };
-                                       }
+                       if let Some(size) = args.size {
+                               code = quote! {
+                                       mt_ser::MtSerialize::mt_serialize::<#size>(&{
+                                               let mut __buf = Vec::new();
+                                               let __writer = &mut __buf;
+                                               #code
+                                               __buf
+                                       }, __writer)?;
                                };
                        }
 
-                       impl_size!(size8, u8);
-                       impl_size!(size16, u16);
-                       impl_size!(size32, u32);
-                       impl_size!(size64, u64);
-
-                       macro_rules! impl_const {
-                               ($name:ident) => {
-                                       if let Some(x) = args.$name {
-                                               code = quote! {
-                                                       #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
-                                                       #code
-                                               };
-                                       }
-                               };
+                       for x in args.const_before.iter().rev() {
+                               code = quote! {
+                                       #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
+                                       #code
+                               }
                        }
 
-                       impl_const!(const8);
-                       impl_const!(const16);
-                       impl_const!(const32);
-                       impl_const!(const64);
+                       for x in args.const_after.iter() {
+                               code = quote! {
+                                       #code
+                                       #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
+                               }
+                       }
 
                        code
                }
@@ -301,50 +259,60 @@ fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) ->
                                }
                        }
 
-                       macro_rules! impl_size {
-                               ($name:ident, $T:ty) => {
-                                       if args.$name {
-                                               code = quote! {
-                                                       $T::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
-                                                               let mut __owned_reader = std::io::Read::take(
-                                                                       mt_ser::WrapRead(__reader), size as u64);
-                                                               let __reader = &mut __owned_reader;
 
-                                                               #code
-                                                       })
-                                               };
-                                       }
+                       if let Some(size) = args.size {
+                               code = quote! {
+                                       #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
+                                               let mut __owned_reader = std::io::Read::take(
+                                                       mt_ser::WrapRead(__reader), size as u64);
+                                               let __reader = &mut __owned_reader;
+
+                                               #code
+                                       })
                                };
                        }
 
-                       impl_size!(size8, u8);
-                       impl_size!(size16, u16);
-                       impl_size!(size32, u32);
-                       impl_size!(size64, u64);
-
-                       macro_rules! impl_const {
-                               ($name:ident) => {
-                                       if let Some(want) = args.$name {
-                                               code = quote! {
-                                                       mt_ser::MtDeserialize::mt_deserialize::<mt_ser::DefCfg>(__reader)
-                                                               .and_then(|got| {
-                                                                       if #want == got {
-                                                                               #code
-                                                                       } else {
-                                                                               Err(mt_ser::DeserializeError::InvalidConst(
-                                                                                       #want as u64, got as u64
-                                                                               ))
-                                                                       }
-                                                               })
-                                               };
+                       let impl_const = |value: TokStr| quote! {
+                               {
+                                       fn deserialize_same_type<T: MtDeserialize>(
+                                               _: &T,
+                                               reader: &mut impl std::io::Read
+                                       ) -> Result<T, DeserializeError> {
+                                               T::mt_deserialize::<mt_ser::DefCfg>(reader)
+                                       }
+
+                                       deserialize_same_type(&want, __reader)
+                                               .and_then(|got| {
+                                                       if want == got {
+                                                               #value
+                                                       } else {
+                                                               Err(mt_ser::DeserializeError::InvalidConst(
+                                                                       Box::new(want), Box::new(got)
+                                                               ))
+                                                       }
+                                               })
+                               }
+                       };
+
+                       for want in args.const_before.iter().rev() {
+                               let imp = impl_const(code);
+                               code = quote! {
+                                       {
+                                               let want = #want;
+                                               #imp
                                        }
                                };
                        }
 
-                       impl_const!(const8);
-                       impl_const!(const16);
-                       impl_const!(const32);
-                       impl_const!(const64);
+                       for want in args.const_after.iter() {
+                               let imp = impl_const(quote! { Ok(value) });
+                               code = quote! {
+                                       {
+                                               let want = #want;
+                                               #code.and_then(|value| { #imp })
+                                       }
+                               };
+                       }
 
                        code
                }
@@ -357,8 +325,9 @@ fn serialize_fields(fields: &Fields) -> TokStr {
                .iter()
                .map(|(ident, field)| {
                        serialize_args(MtArgs::from_field(field), |args| {
-                               let cfg = get_cfg(args);
-                               quote! { mt_ser::MtSerialize::mt_serialize::<#cfg>(#ident, __writer)?; }
+                               let def = parse_quote! { mt_ser::DefCfg };
+                               let len = args.len.as_ref().unwrap_or(&def);
+                               quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }
                        })
                })
                .collect()
@@ -369,8 +338,9 @@ fn deserialize_fields(fields: &Fields) -> TokStr {
                .iter()
                .map(|(ident, field)| {
                        let code = deserialize_args(MtArgs::from_field(field), |args| {
-                               let cfg = get_cfg(args);
-                               let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#cfg>(__reader) };
+                               let def = parse_quote! { mt_ser::DefCfg };
+                               let len = args.len.as_ref().unwrap_or(&def);
+                               let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
 
                                if args.default {
                                        code = quote! {
index f57e487eaa8c635e4eddee7196fe4fe85613fd20..4eaf119fdebb59f8386f316af30f554e560655e5 100644 (file)
@@ -12,6 +12,7 @@ use paste::paste as paste_macro;
 use std::{
     collections::{HashMap, HashSet},
     convert::Infallible,
+    fmt::Debug,
     io::{self, Read, Write},
     num::TryFromIntError,
     ops::Deref,
@@ -47,8 +48,8 @@ pub enum DeserializeError {
     InvalidUtf16(#[from] std::char::DecodeUtf16Error),
     #[error("invalid {0} enum variant {1}")]
     InvalidEnumVariant(&'static str, u64),
-    #[error("invalid constant - wanted: {0} - got: {1}")]
-    InvalidConst(u64, u64),
+    #[error("invalid constant - wanted: {0:?} - got: {1:?}")]
+    InvalidConst(Box<dyn Debug>, Box<dyn Debug>),
 }
 
 impl From<Infallible> for DeserializeError {
@@ -228,7 +229,7 @@ impl MtLen for () {
     }
 }
 
-pub struct Utf16<B: MtCfg>(pub B);
+pub struct Utf16<B: MtCfg = DefCfg>(pub B);
 
 impl<B: MtCfg> MtCfg for Utf16<B> {
     type Len = B::Len;