From: Lizzy Fleckenstein Date: Wed, 15 Feb 2023 00:00:37 +0000 (+0100) Subject: Support Ranges X-Git-Url: https://git.lizzy.rs/?a=commitdiff_plain;h=3561472c60acfeda2dadb477a1da4afd287ab30a;p=mt_ser.git Support Ranges --- diff --git a/derive/src/lib.rs b/derive/src/lib.rs index d4d8991..27af31b 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -81,16 +81,20 @@ pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream { }); if args.enumset { - let repr_str = args - .repr - .expect("missing repr for enum") - .to_token_stream() - .to_string(); - out.extend(quote! { #[derive(EnumSetType)] - #[enumset(repr = #repr_str, serialize_as_map)] - }) + #[enumset(serialize_as_map)] + }); + + if let Some(repr) = args.repr { + let repr_str = repr.to_token_stream().to_string(); + + out.extend(quote! { + #[enumset(repr = #repr_str)] + }); + } else if !args.custom { + panic!("missing repr for enum"); + } } else { let has_payload = e .variants @@ -182,6 +186,7 @@ struct MtArgs { string_repr: bool, // for enums zlib: bool, zstd: bool, + typename: Option, // remote derive } type Fields<'a> = Vec<(TokStr, &'a syn::Field)>; @@ -203,155 +208,137 @@ fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields } } -fn serialize_args(res: darling::Result, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr { - match res { - Ok(args) => { - let mut code = body(&args); - - macro_rules! impl_compress { - ($create:expr) => { - code = quote! { - let mut __writer = { - let mut __stream = $create; - let __writer = &mut __stream; - #code - __stream.finish()? - }; - }; +fn serialize_args(args: &MtArgs, code: &mut TokStr) { + macro_rules! impl_compress { + ($create:expr) => { + *code = quote! { + let mut __writer = { + let mut __stream = $create; + let __writer = &mut __stream; + #code + __stream.finish()? }; - } - - if args.zlib { - impl_compress!(mt_ser::flate2::write::ZlibEncoder::new( - __writer, - mt_ser::flate2::Compression::default() - )); - } + }; + }; + } - if args.zstd { - impl_compress!(mt_ser::zstd::stream::write::Encoder::new(__writer, 0)?); - } + if args.zlib { + impl_compress!(mt_ser::flate2::write::ZlibEncoder::new( + __writer, + mt_ser::flate2::Compression::default() + )); + } - if let Some(size) = args.size { - code = quote! { - mt_ser::MtSerialize::mt_serialize::<#size>(&{ - let mut __buf = Vec::new(); - let __writer = &mut __buf; - #code - __buf - }, __writer)?; - }; - } + if args.zstd { + impl_compress!(mt_ser::zstd::stream::write::Encoder::new(__writer, 0)?); + } - for x in args.const_before.iter().rev() { - code = quote! { - #x.mt_serialize::(__writer)?; - #code - } - } + if let Some(size) = &args.size { + *code = quote! { + mt_ser::MtSerialize::mt_serialize::<#size>(&{ + let mut __buf = Vec::new(); + let __writer = &mut __buf; + #code + __buf + }, __writer)?; + }; + } - for x in args.const_after.iter() { - code = quote! { - #code - #x.mt_serialize::(__writer)?; - } - } + for x in args.const_before.iter().rev() { + *code = quote! { + #x.mt_serialize::(__writer)?; + #code + } + } - code + for x in args.const_after.iter() { + *code = quote! { + #code + #x.mt_serialize::(__writer)?; } - Err(e) => e.write_errors(), } } -fn deserialize_args(res: darling::Result, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr { - match res { - Ok(args) => { - let mut code = body(&args); - - macro_rules! impl_compress { - ($create:expr) => { - code = quote! { - { - let mut __owned_reader = $create; - let __reader = &mut __owned_reader; +fn deserialize_args(args: &MtArgs, code: &mut TokStr) { + macro_rules! impl_compress { + ($create:expr) => { + *code = quote! { + { + let mut __owned_reader = $create; + let __reader = &mut __owned_reader; - #code - } - } - }; + #code + } } + }; + } - if args.zlib { - impl_compress!(mt_ser::flate2::read::ZlibDecoder::new(mt_ser::WrapRead( - __reader - ))); - } + if args.zlib { + impl_compress!(mt_ser::flate2::read::ZlibDecoder::new(mt_ser::WrapRead( + __reader + ))); + } - if args.zstd { - impl_compress!(mt_ser::zstd::stream::read::Decoder::new(mt_ser::WrapRead( - __reader - ))?); - } + if args.zstd { + impl_compress!(mt_ser::zstd::stream::read::Decoder::new(mt_ser::WrapRead( + __reader + ))?); + } - if let Some(size) = args.size { - code = quote! { - #size::mt_deserialize::(__reader).and_then(|size| { - let mut __owned_reader = std::io::Read::take( - mt_ser::WrapRead(__reader), size as u64); - let __reader = &mut __owned_reader; + if let Some(size) = &args.size { + *code = quote! { + #size::mt_deserialize::(__reader).and_then(|size| { + let mut __owned_reader = std::io::Read::take( + mt_ser::WrapRead(__reader), size as u64); + let __reader = &mut __owned_reader; - #code - }) - }; - } - - let impl_const = |value: TokStr| { - quote! { - { - fn deserialize_same_type( - _: &T, - reader: &mut impl std::io::Read - ) -> Result { - T::mt_deserialize::(reader) - } + #code + }) + }; + } - deserialize_same_type(&want, __reader) - .and_then(|got| { - if want == got { - #value - } else { - Err(mt_ser::DeserializeError::InvalidConst( - Box::new(want), Box::new(got) - )) - } - }) - } + let impl_const = |value: &TokStr| { + quote! { + { + fn deserialize_same_type( + _: &T, + reader: &mut impl std::io::Read + ) -> Result { + T::mt_deserialize::(reader) } - }; - for want in args.const_before.iter().rev() { - let imp = impl_const(code); - code = quote! { - { - let want = #want; - #imp - } - }; + deserialize_same_type(&want, __reader) + .and_then(|got| { + if want == got { + #value + } else { + Err(mt_ser::DeserializeError::InvalidConst( + Box::new(want), Box::new(got) + )) + } + }) } + } + }; - for want in args.const_after.iter() { - let imp = impl_const(quote! { Ok(value) }); - code = quote! { - { - let want = #want; - #code.and_then(|value| { #imp }) - } - }; + for want in args.const_before.iter().rev() { + let imp = impl_const(&code); + *code = quote! { + { + let want = #want; + #imp } + }; + } - code - } - Err(e) => e.write_errors(), + for want in args.const_after.iter() { + let imp = impl_const("e! { Ok(value) }); + *code = quote! { + { + let want = #want; + #code.and_then(|value| { #imp }) + } + }; } } @@ -359,11 +346,14 @@ fn serialize_fields(fields: &Fields) -> TokStr { fields .iter() .map(|(ident, field)| { - serialize_args(MtArgs::from_field(field), |args| { - let def = parse_quote! { mt_ser::DefCfg }; - let len = args.len.as_ref().unwrap_or(&def); - quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; } - }) + let args = MtArgs::from_field(field).unwrap(); + let def = parse_quote! { mt_ser::DefCfg }; + let len = args.len.as_ref().unwrap_or(&def); + + let mut code = quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }; + serialize_args(&args, &mut code); + + code }) .collect() } @@ -372,19 +362,19 @@ fn deserialize_fields(fields: &Fields) -> TokStr { fields .iter() .map(|(ident, field)| { - let code = deserialize_args(MtArgs::from_field(field), |args| { - let def = parse_quote! { mt_ser::DefCfg }; - let len = args.len.as_ref().unwrap_or(&def); - let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) }; - - if args.default { - code = quote! { - mt_ser::OrDefault::or_default(#code) - }; - } + let args = MtArgs::from_field(field).unwrap(); - code - }); + let def = parse_quote! { mt_ser::DefCfg }; + let len = args.len.as_ref().unwrap_or(&def); + let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) }; + + if args.default { + code = quote! { + mt_ser::OrDefault::or_default(#code) + }; + } + + deserialize_args(&args, &mut code); quote! { let #ident = #code?; @@ -451,46 +441,57 @@ fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Varia #[proc_macro_derive(MtSerialize, attributes(mt))] pub fn derive_serialize(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as syn::DeriveInput); - let typename = &input.ident; + let args = MtArgs::from_derive_input(&input).unwrap(); + let typename = args.typename.as_ref().unwrap_or(&input.ident); + let generics = &input.generics; + let mut generics_bounded = generics.clone(); + for t in generics_bounded.params.iter_mut() { + match t { + syn::GenericParam::Type(x) => *t = parse_quote! { #x: mt_ser::MtSerialize }, + _ => {} + } + } - let code = serialize_args(MtArgs::from_derive_input(&input), |args| { - match &input.data { - syn::Data::Enum(e) => { - let repr = get_repr(&input, args); - let mut variants = TokStr::new(); + let mut code = match &input.data { + syn::Data::Enum(e) => { + let repr = get_repr(&input, &args); + let mut variants = TokStr::new(); - iter_variants(e, args, |v, discr| { - let (fields, fields_struct) = get_fields_struct(&v.fields); - let code = - serialize_args(MtArgs::from_variant(v), |_| serialize_fields(&fields)); - let ident = &v.ident; + iter_variants(e, &args, |v, discr| { + let args = MtArgs::from_variant(v).unwrap(); - variants.extend(quote! { + let (fields, fields_struct) = get_fields_struct(&v.fields); + + let mut code = serialize_fields(&fields); + serialize_args(&args, &mut code); + + let ident = &v.ident; + + variants.extend(quote! { #typename::#ident #fields_struct => { mt_ser::MtSerialize::mt_serialize::(&((#discr) as #repr), __writer)?; #code } }); - }); + }); - quote! { - match self { - #variants - } + quote! { + match self { + #variants } } - syn::Data::Struct(s) => { - serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })) - } - _ => { - panic!("only enum and struct supported"); - } } - }); + syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })), + _ => { + panic!("only enum and struct supported"); + } + }; + + serialize_args(&args, &mut code); quote! { #[automatically_derived] - impl mt_ser::MtSerialize for #typename { + impl #generics_bounded mt_ser::MtSerialize for #typename #generics { fn mt_serialize(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> { #code @@ -503,77 +504,87 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream { #[proc_macro_derive(MtDeserialize, attributes(mt))] pub fn derive_deserialize(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as syn::DeriveInput); - let typename = &input.ident; + let args = MtArgs::from_derive_input(&input).unwrap(); + let typename = args.typename.as_ref().unwrap_or(&input.ident); + let generics = &input.generics; + let mut generics_bounded = generics.clone(); + for t in generics_bounded.params.iter_mut() { + match t { + syn::GenericParam::Type(x) => *t = parse_quote! { #x: mt_ser::MtDeserialize }, + _ => {} + } + } - let code = deserialize_args(MtArgs::from_derive_input(&input), |args| { - match &input.data { - syn::Data::Enum(e) => { - let repr = get_repr(&input, args); + let mut code = match &input.data { + syn::Data::Enum(e) => { + let repr = get_repr(&input, &args); - let mut consts = TokStr::new(); - let mut arms = TokStr::new(); + let mut consts = TokStr::new(); + let mut arms = TokStr::new(); - iter_variants(e, args, |v, discr| { - let ident = &v.ident; - let (fields, fields_struct) = get_fields_struct(&v.fields); - let code = deserialize_args(MtArgs::from_variant(v), |_| { - let fields_code = deserialize_fields(&fields); + iter_variants(e, &args, |v, discr| { + let args = MtArgs::from_variant(v).unwrap(); - quote! { - #fields_code - Ok(Self::#ident #fields_struct) - } - }); + let ident = &v.ident; + let (fields, fields_struct) = get_fields_struct(&v.fields); + let mut code = deserialize_fields(&fields); + code = quote! { + #code + Ok(Self::#ident #fields_struct) + }; - consts.extend(quote! { - const #ident: #repr = #discr; - }); + deserialize_args(&args, &mut code); - arms.extend(quote! { - #ident => { #code } - }); + consts.extend(quote! { + const #ident: #repr = #discr; }); - let type_str = typename.to_string(); - let discr_match = if args.string_repr { - quote! { - let __discr = String::mt_deserialize::(__reader)?; - match __discr.as_str() - } - } else { - quote! { - let __discr = mt_ser::MtDeserialize::mt_deserialize::(__reader)?; - match __discr - } - }; + arms.extend(quote! { + #ident => { #code } + }); + }); + let type_str = typename.to_string(); + let discr_match = if args.string_repr { quote! { - #consts - - #discr_match { - #arms - _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr))) - } + let __discr = String::mt_deserialize::(__reader)?; + match __discr.as_str() } - } - syn::Data::Struct(s) => { - let (fields, fields_struct) = get_fields_struct(&s.fields); - let code = deserialize_fields(&fields); - + } else { quote! { - #code - Ok(Self #fields_struct) + let __discr = mt_ser::MtDeserialize::mt_deserialize::(__reader)?; + match __discr + } + }; + + quote! { + #consts + + #discr_match { + #arms + _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr))) } } - _ => { - panic!("only enum and struct supported"); + } + syn::Data::Struct(s) => { + let (fields, fields_struct) = get_fields_struct(&s.fields); + let code = deserialize_fields(&fields); + + quote! { + #code + Ok(Self #fields_struct) } } - }); + _ => { + panic!("only enum and struct supported"); + } + }; + + deserialize_args(&args, &mut code); quote! { #[automatically_derived] - impl mt_ser::MtDeserialize for #typename { + impl #generics_bounded mt_ser::MtDeserialize for #typename #generics { #[allow(non_upper_case_globals)] fn mt_deserialize(__reader: &mut impl std::io::Read) -> Result { #code diff --git a/src/lib.rs b/src/lib.rs index 961e416..b2aebc8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,13 +16,15 @@ use std::{ fmt::Debug, io::{self, Read, Write}, num::TryFromIntError, - ops::Deref, + ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}, }; use thiserror::Error; #[cfg(test)] mod tests; +use crate as mt_ser; + #[derive(Error, Debug)] pub enum SerializeError { #[error("io error: {0}")] @@ -559,3 +561,56 @@ impl MtDeserialize for Box { Ok(Self::new(T::mt_deserialize::(reader)?)) } } + +#[derive(MtSerialize, MtDeserialize)] +#[mt(typename = "Range")] +#[allow(unused)] +struct RemoteRange { + start: T, + end: T, +} + +#[derive(MtSerialize, MtDeserialize)] +#[mt(typename = "RangeFrom")] +#[allow(unused)] +struct RemoteRangeFrom { + start: T, +} + +#[derive(MtSerialize, MtDeserialize)] +#[mt(typename = "RangeFull")] +#[allow(unused)] +struct RemoteRangeFull; + +// RangeInclusive fields are private +impl MtSerialize for RangeInclusive { + fn mt_serialize(&self, writer: &mut impl Write) -> Result<(), SerializeError> { + self.start().mt_serialize::(writer)?; + self.end().mt_serialize::(writer)?; + + Ok(()) + } +} + +impl MtDeserialize for RangeInclusive { + fn mt_deserialize(reader: &mut impl Read) -> Result { + let start = T::mt_deserialize::(reader)?; + let end = T::mt_deserialize::(reader)?; + + Ok(start..=end) + } +} + +#[derive(MtSerialize, MtDeserialize)] +#[mt(typename = "RangeTo")] +#[allow(unused)] +struct RemoteRangeTo { + end: T, +} + +#[derive(MtSerialize, MtDeserialize)] +#[mt(typename = "RangeToInclusive")] +#[allow(unused)] +struct RemoteRangeToInclusive { + end: T, +}