From 7fa81df83339c451178e27aa26cf6f1a331ccd4a Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Mon, 13 Feb 2023 16:10:48 +0100 Subject: [PATCH] Improve attributes --- derive/src/lib.rs | 190 +++++++++++++++++++--------------------------- src/lib.rs | 7 +- 2 files changed, 84 insertions(+), 113 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 267ec00..52de2f2 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -158,50 +158,19 @@ pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream { #[darling(attributes(mt))] #[darling(default)] struct MtArgs { - const8: Option, - const16: Option, - const32: Option, - const64: Option, - size8: bool, - size16: bool, - size32: bool, - size64: bool, - len0: bool, - len8: bool, - len16: bool, - len32: bool, - len64: bool, - utf16: bool, - zlib: bool, - zstd: bool, // TODO - default: bool, -} + #[darling(multiple)] + const_before: Vec, // must implement MtSerialize + MtDeserialize + PartialEq -fn get_cfg(args: &MtArgs) -> syn::Type { - let mut ty: syn::Type = parse_quote! { mt_ser::DefCfg }; + #[darling(multiple)] + const_after: Vec, // must implement MtSerialize + MtDeserialize + PartialEq - if args.len0 { - ty = parse_quote! { () }; - } + size: Option, // must implement MtCfg - macro_rules! impl_len { - ($name:ident, $T:ty) => { - if args.$name { - ty = parse_quote! { $T }; - } - }; - } - - impl_len!(len8, u8); - impl_len!(len16, u16); - impl_len!(len32, u32); - impl_len!(len64, u64); + len: Option, // must implement MtCfg - if args.utf16 { - ty = parse_quote! { mt_ser::Utf16<#ty> }; - } - - ty + zlib: bool, + zstd: bool, // TODO + default: bool, // type must implement Default } type Fields<'a> = Vec<(TokStr, &'a syn::Field)>; @@ -242,41 +211,30 @@ fn serialize_args(res: darling::Result, body: impl FnOnce(&MtArgs) -> To }; } - macro_rules! impl_size { - ($name:ident, $T:ty) => { - if args.$name { - code = quote! { - mt_ser::MtSerialize::mt_serialize::<$T>(&{ - let mut __buf = Vec::new(); - let __writer = &mut __buf; - #code - __buf - }, __writer)?; - }; - } + if let Some(size) = args.size { + code = quote! { + mt_ser::MtSerialize::mt_serialize::<#size>(&{ + let mut __buf = Vec::new(); + let __writer = &mut __buf; + #code + __buf + }, __writer)?; }; } - impl_size!(size8, u8); - impl_size!(size16, u16); - impl_size!(size32, u32); - impl_size!(size64, u64); - - macro_rules! impl_const { - ($name:ident) => { - if let Some(x) = args.$name { - code = quote! { - #x.mt_serialize::(__writer)?; - #code - }; - } - }; + for x in args.const_before.iter().rev() { + code = quote! { + #x.mt_serialize::(__writer)?; + #code + } } - impl_const!(const8); - impl_const!(const16); - impl_const!(const32); - impl_const!(const64); + for x in args.const_after.iter() { + code = quote! { + #code + #x.mt_serialize::(__writer)?; + } + } code } @@ -301,50 +259,60 @@ fn deserialize_args(res: darling::Result, body: impl FnOnce(&MtArgs) -> } } - macro_rules! impl_size { - ($name:ident, $T:ty) => { - if args.$name { - code = quote! { - $T::mt_deserialize::(__reader).and_then(|size| { - let mut __owned_reader = std::io::Read::take( - mt_ser::WrapRead(__reader), size as u64); - let __reader = &mut __owned_reader; - #code - }) - }; - } + if let Some(size) = args.size { + code = quote! { + #size::mt_deserialize::(__reader).and_then(|size| { + let mut __owned_reader = std::io::Read::take( + mt_ser::WrapRead(__reader), size as u64); + let __reader = &mut __owned_reader; + + #code + }) }; } - impl_size!(size8, u8); - impl_size!(size16, u16); - impl_size!(size32, u32); - impl_size!(size64, u64); - - macro_rules! impl_const { - ($name:ident) => { - if let Some(want) = args.$name { - code = quote! { - mt_ser::MtDeserialize::mt_deserialize::(__reader) - .and_then(|got| { - if #want == got { - #code - } else { - Err(mt_ser::DeserializeError::InvalidConst( - #want as u64, got as u64 - )) - } - }) - }; + let impl_const = |value: TokStr| quote! { + { + fn deserialize_same_type( + _: &T, + reader: &mut impl std::io::Read + ) -> Result { + T::mt_deserialize::(reader) + } + + deserialize_same_type(&want, __reader) + .and_then(|got| { + if want == got { + #value + } else { + Err(mt_ser::DeserializeError::InvalidConst( + Box::new(want), Box::new(got) + )) + } + }) + } + }; + + for want in args.const_before.iter().rev() { + let imp = impl_const(code); + code = quote! { + { + let want = #want; + #imp } }; } - impl_const!(const8); - impl_const!(const16); - impl_const!(const32); - impl_const!(const64); + for want in args.const_after.iter() { + let imp = impl_const(quote! { Ok(value) }); + code = quote! { + { + let want = #want; + #code.and_then(|value| { #imp }) + } + }; + } code } @@ -357,8 +325,9 @@ fn serialize_fields(fields: &Fields) -> TokStr { .iter() .map(|(ident, field)| { serialize_args(MtArgs::from_field(field), |args| { - let cfg = get_cfg(args); - quote! { mt_ser::MtSerialize::mt_serialize::<#cfg>(#ident, __writer)?; } + let def = parse_quote! { mt_ser::DefCfg }; + let len = args.len.as_ref().unwrap_or(&def); + quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; } }) }) .collect() @@ -369,8 +338,9 @@ fn deserialize_fields(fields: &Fields) -> TokStr { .iter() .map(|(ident, field)| { let code = deserialize_args(MtArgs::from_field(field), |args| { - let cfg = get_cfg(args); - let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#cfg>(__reader) }; + let def = parse_quote! { mt_ser::DefCfg }; + let len = args.len.as_ref().unwrap_or(&def); + let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) }; if args.default { code = quote! { diff --git a/src/lib.rs b/src/lib.rs index f57e487..4eaf119 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ use paste::paste as paste_macro; use std::{ collections::{HashMap, HashSet}, convert::Infallible, + fmt::Debug, io::{self, Read, Write}, num::TryFromIntError, ops::Deref, @@ -47,8 +48,8 @@ pub enum DeserializeError { InvalidUtf16(#[from] std::char::DecodeUtf16Error), #[error("invalid {0} enum variant {1}")] InvalidEnumVariant(&'static str, u64), - #[error("invalid constant - wanted: {0} - got: {1}")] - InvalidConst(u64, u64), + #[error("invalid constant - wanted: {0:?} - got: {1:?}")] + InvalidConst(Box, Box), } impl From for DeserializeError { @@ -228,7 +229,7 @@ impl MtLen for () { } } -pub struct Utf16(pub B); +pub struct Utf16(pub B); impl MtCfg for Utf16 { type Len = B::Len; -- 2.44.0