]> git.lizzy.rs Git - mt_ser.git/commitdiff
Support Ranges
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Wed, 15 Feb 2023 00:00:37 +0000 (01:00 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Wed, 15 Feb 2023 00:00:37 +0000 (01:00 +0100)
derive/src/lib.rs
src/lib.rs

index d4d89913c4529558884fd57ad51124e16b62fbc7..27af31b3f9457b29d39e9a3b9ec928da74028fb6 100644 (file)
@@ -81,16 +81,20 @@ pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
             });
 
             if args.enumset {
-                let repr_str = args
-                    .repr
-                    .expect("missing repr for enum")
-                    .to_token_stream()
-                    .to_string();
-
                 out.extend(quote! {
                     #[derive(EnumSetType)]
-                    #[enumset(repr = #repr_str, serialize_as_map)]
-                })
+                    #[enumset(serialize_as_map)]
+                });
+
+                if let Some(repr) = args.repr {
+                    let repr_str = repr.to_token_stream().to_string();
+
+                    out.extend(quote! {
+                        #[enumset(repr = #repr_str)]
+                    });
+                } else if !args.custom {
+                    panic!("missing repr for enum");
+                }
             } else {
                 let has_payload = e
                     .variants
@@ -182,6 +186,7 @@ struct MtArgs {
     string_repr: bool,       // for enums
     zlib: bool,
     zstd: bool,
+    typename: Option<syn::Ident>, // remote derive
 }
 
 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
@@ -203,155 +208,137 @@ fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields
     }
 }
 
-fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
-    match res {
-        Ok(args) => {
-            let mut code = body(&args);
-
-            macro_rules! impl_compress {
-                ($create:expr) => {
-                    code = quote! {
-                        let mut __writer = {
-                            let mut __stream = $create;
-                            let __writer = &mut __stream;
-                            #code
-                            __stream.finish()?
-                        };
-                    };
+fn serialize_args(args: &MtArgs, code: &mut TokStr) {
+    macro_rules! impl_compress {
+        ($create:expr) => {
+            *code = quote! {
+                let mut __writer = {
+                    let mut __stream = $create;
+                    let __writer = &mut __stream;
+                    #code
+                    __stream.finish()?
                 };
-            }
-
-            if args.zlib {
-                impl_compress!(mt_ser::flate2::write::ZlibEncoder::new(
-                    __writer,
-                    mt_ser::flate2::Compression::default()
-                ));
-            }
+            };
+        };
+    }
 
-            if args.zstd {
-                impl_compress!(mt_ser::zstd::stream::write::Encoder::new(__writer, 0)?);
-            }
+    if args.zlib {
+        impl_compress!(mt_ser::flate2::write::ZlibEncoder::new(
+            __writer,
+            mt_ser::flate2::Compression::default()
+        ));
+    }
 
-            if let Some(size) = args.size {
-                code = quote! {
-                    mt_ser::MtSerialize::mt_serialize::<#size>(&{
-                        let mut __buf = Vec::new();
-                        let __writer = &mut __buf;
-                        #code
-                        __buf
-                    }, __writer)?;
-                };
-            }
+    if args.zstd {
+        impl_compress!(mt_ser::zstd::stream::write::Encoder::new(__writer, 0)?);
+    }
 
-            for x in args.const_before.iter().rev() {
-                code = quote! {
-                    #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
-                    #code
-                }
-            }
+    if let Some(size) = &args.size {
+        *code = quote! {
+            mt_ser::MtSerialize::mt_serialize::<#size>(&{
+                let mut __buf = Vec::new();
+                let __writer = &mut __buf;
+                #code
+                __buf
+            }, __writer)?;
+        };
+    }
 
-            for x in args.const_after.iter() {
-                code = quote! {
-                    #code
-                    #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
-                }
-            }
+    for x in args.const_before.iter().rev() {
+        *code = quote! {
+            #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
+            #code
+        }
+    }
 
-            code
+    for x in args.const_after.iter() {
+        *code = quote! {
+            #code
+            #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
         }
-        Err(e) => e.write_errors(),
     }
 }
 
-fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
-    match res {
-        Ok(args) => {
-            let mut code = body(&args);
-
-            macro_rules! impl_compress {
-                ($create:expr) => {
-                    code = quote! {
-                        {
-                            let mut __owned_reader = $create;
-                            let __reader = &mut __owned_reader;
+fn deserialize_args(args: &MtArgs, code: &mut TokStr) {
+    macro_rules! impl_compress {
+        ($create:expr) => {
+            *code = quote! {
+                {
+                    let mut __owned_reader = $create;
+                    let __reader = &mut __owned_reader;
 
-                            #code
-                        }
-                    }
-                };
+                    #code
+                }
             }
+        };
+    }
 
-            if args.zlib {
-                impl_compress!(mt_ser::flate2::read::ZlibDecoder::new(mt_ser::WrapRead(
-                    __reader
-                )));
-            }
+    if args.zlib {
+        impl_compress!(mt_ser::flate2::read::ZlibDecoder::new(mt_ser::WrapRead(
+            __reader
+        )));
+    }
 
-            if args.zstd {
-                impl_compress!(mt_ser::zstd::stream::read::Decoder::new(mt_ser::WrapRead(
-                    __reader
-                ))?);
-            }
+    if args.zstd {
+        impl_compress!(mt_ser::zstd::stream::read::Decoder::new(mt_ser::WrapRead(
+            __reader
+        ))?);
+    }
 
-            if let Some(size) = args.size {
-                code = quote! {
-                    #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
-                        let mut __owned_reader = std::io::Read::take(
-                            mt_ser::WrapRead(__reader), size as u64);
-                        let __reader = &mut __owned_reader;
+    if let Some(size) = &args.size {
+        *code = quote! {
+            #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
+                let mut __owned_reader = std::io::Read::take(
+                    mt_ser::WrapRead(__reader), size as u64);
+                let __reader = &mut __owned_reader;
 
-                        #code
-                    })
-                };
-            }
-
-            let impl_const = |value: TokStr| {
-                quote! {
-                    {
-                        fn deserialize_same_type<T: MtDeserialize>(
-                            _: &T,
-                            reader: &mut impl std::io::Read
-                        ) -> Result<T, mt_ser::DeserializeError> {
-                            T::mt_deserialize::<mt_ser::DefCfg>(reader)
-                        }
+                #code
+            })
+        };
+    }
 
-                        deserialize_same_type(&want, __reader)
-                            .and_then(|got| {
-                                if want == got {
-                                    #value
-                                } else {
-                                    Err(mt_ser::DeserializeError::InvalidConst(
-                                        Box::new(want), Box::new(got)
-                                    ))
-                                }
-                            })
-                    }
+    let impl_const = |value: &TokStr| {
+        quote! {
+            {
+                fn deserialize_same_type<T: MtDeserialize>(
+                    _: &T,
+                    reader: &mut impl std::io::Read
+                ) -> Result<T, mt_ser::DeserializeError> {
+                    T::mt_deserialize::<mt_ser::DefCfg>(reader)
                 }
-            };
 
-            for want in args.const_before.iter().rev() {
-                let imp = impl_const(code);
-                code = quote! {
-                    {
-                        let want = #want;
-                        #imp
-                    }
-                };
+                deserialize_same_type(&want, __reader)
+                    .and_then(|got| {
+                        if want == got {
+                            #value
+                        } else {
+                            Err(mt_ser::DeserializeError::InvalidConst(
+                                Box::new(want), Box::new(got)
+                            ))
+                        }
+                    })
             }
+        }
+    };
 
-            for want in args.const_after.iter() {
-                let imp = impl_const(quote! { Ok(value) });
-                code = quote! {
-                    {
-                        let want = #want;
-                        #code.and_then(|value| { #imp })
-                    }
-                };
+    for want in args.const_before.iter().rev() {
+        let imp = impl_const(&code);
+        *code = quote! {
+            {
+                let want = #want;
+                #imp
             }
+        };
+    }
 
-            code
-        }
-        Err(e) => e.write_errors(),
+    for want in args.const_after.iter() {
+        let imp = impl_const(&quote! { Ok(value) });
+        *code = quote! {
+            {
+                let want = #want;
+                #code.and_then(|value| { #imp })
+            }
+        };
     }
 }
 
@@ -359,11 +346,14 @@ fn serialize_fields(fields: &Fields) -> TokStr {
     fields
         .iter()
         .map(|(ident, field)| {
-            serialize_args(MtArgs::from_field(field), |args| {
-                let def = parse_quote! { mt_ser::DefCfg };
-                let len = args.len.as_ref().unwrap_or(&def);
-                quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }
-            })
+            let args = MtArgs::from_field(field).unwrap();
+            let def = parse_quote! { mt_ser::DefCfg };
+            let len = args.len.as_ref().unwrap_or(&def);
+
+            let mut code = quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; };
+            serialize_args(&args, &mut code);
+
+            code
         })
         .collect()
 }
@@ -372,19 +362,19 @@ fn deserialize_fields(fields: &Fields) -> TokStr {
     fields
         .iter()
         .map(|(ident, field)| {
-            let code = deserialize_args(MtArgs::from_field(field), |args| {
-                let def = parse_quote! { mt_ser::DefCfg };
-                let len = args.len.as_ref().unwrap_or(&def);
-                let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
-
-                if args.default {
-                    code = quote! {
-                        mt_ser::OrDefault::or_default(#code)
-                    };
-                }
+            let args = MtArgs::from_field(field).unwrap();
 
-                code
-            });
+            let def = parse_quote! { mt_ser::DefCfg };
+            let len = args.len.as_ref().unwrap_or(&def);
+            let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
+
+            if args.default {
+                code = quote! {
+                    mt_ser::OrDefault::or_default(#code)
+                };
+            }
+
+            deserialize_args(&args, &mut code);
 
             quote! {
                 let #ident = #code?;
@@ -451,46 +441,57 @@ fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Varia
 #[proc_macro_derive(MtSerialize, attributes(mt))]
 pub fn derive_serialize(input: TokenStream) -> TokenStream {
     let input = parse_macro_input!(input as syn::DeriveInput);
-    let typename = &input.ident;
+    let args = MtArgs::from_derive_input(&input).unwrap();
+    let typename = args.typename.as_ref().unwrap_or(&input.ident);
+    let generics = &input.generics;
+    let mut generics_bounded = generics.clone();
+    for t in generics_bounded.params.iter_mut() {
+        match t {
+            syn::GenericParam::Type(x) => *t = parse_quote! { #x: mt_ser::MtSerialize },
+            _ => {}
+        }
+    }
 
-    let code = serialize_args(MtArgs::from_derive_input(&input), |args| {
-        match &input.data {
-            syn::Data::Enum(e) => {
-                let repr = get_repr(&input, args);
-                let mut variants = TokStr::new();
+    let mut code = match &input.data {
+        syn::Data::Enum(e) => {
+            let repr = get_repr(&input, &args);
+            let mut variants = TokStr::new();
 
-                iter_variants(e, args, |v, discr| {
-                    let (fields, fields_struct) = get_fields_struct(&v.fields);
-                    let code =
-                        serialize_args(MtArgs::from_variant(v), |_| serialize_fields(&fields));
-                    let ident = &v.ident;
+            iter_variants(e, &args, |v, discr| {
+                let args = MtArgs::from_variant(v).unwrap();
 
-                    variants.extend(quote! {
+                let (fields, fields_struct) = get_fields_struct(&v.fields);
+
+                let mut code = serialize_fields(&fields);
+                serialize_args(&args, &mut code);
+
+                let ident = &v.ident;
+
+                variants.extend(quote! {
                                        #typename::#ident #fields_struct => {
                                                mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
                                                #code
                                        }
                                });
-                });
+            });
 
-                quote! {
-                    match self {
-                        #variants
-                    }
+            quote! {
+                match self {
+                    #variants
                 }
             }
-            syn::Data::Struct(s) => {
-                serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f }))
-            }
-            _ => {
-                panic!("only enum and struct supported");
-            }
         }
-    });
+        syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
+        _ => {
+            panic!("only enum and struct supported");
+        }
+    };
+
+    serialize_args(&args, &mut code);
 
     quote! {
                #[automatically_derived]
-               impl mt_ser::MtSerialize for #typename {
+               impl #generics_bounded mt_ser::MtSerialize for #typename #generics {
                        fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
                                #code
 
@@ -503,77 +504,87 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
 #[proc_macro_derive(MtDeserialize, attributes(mt))]
 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
     let input = parse_macro_input!(input as syn::DeriveInput);
-    let typename = &input.ident;
+    let args = MtArgs::from_derive_input(&input).unwrap();
+    let typename = args.typename.as_ref().unwrap_or(&input.ident);
+    let generics = &input.generics;
+    let mut generics_bounded = generics.clone();
+    for t in generics_bounded.params.iter_mut() {
+        match t {
+            syn::GenericParam::Type(x) => *t = parse_quote! { #x: mt_ser::MtDeserialize },
+            _ => {}
+        }
+    }
 
-    let code = deserialize_args(MtArgs::from_derive_input(&input), |args| {
-        match &input.data {
-            syn::Data::Enum(e) => {
-                let repr = get_repr(&input, args);
+    let mut code = match &input.data {
+        syn::Data::Enum(e) => {
+            let repr = get_repr(&input, &args);
 
-                let mut consts = TokStr::new();
-                let mut arms = TokStr::new();
+            let mut consts = TokStr::new();
+            let mut arms = TokStr::new();
 
-                iter_variants(e, args, |v, discr| {
-                    let ident = &v.ident;
-                    let (fields, fields_struct) = get_fields_struct(&v.fields);
-                    let code = deserialize_args(MtArgs::from_variant(v), |_| {
-                        let fields_code = deserialize_fields(&fields);
+            iter_variants(e, &args, |v, discr| {
+                let args = MtArgs::from_variant(v).unwrap();
 
-                        quote! {
-                            #fields_code
-                            Ok(Self::#ident #fields_struct)
-                        }
-                    });
+                let ident = &v.ident;
+                let (fields, fields_struct) = get_fields_struct(&v.fields);
+                let mut code = deserialize_fields(&fields);
+                code = quote! {
+                    #code
+                    Ok(Self::#ident #fields_struct)
+                };
 
-                    consts.extend(quote! {
-                        const #ident: #repr = #discr;
-                    });
+                deserialize_args(&args, &mut code);
 
-                    arms.extend(quote! {
-                        #ident => { #code }
-                    });
+                consts.extend(quote! {
+                    const #ident: #repr = #discr;
                 });
 
-                let type_str = typename.to_string();
-                let discr_match = if args.string_repr {
-                    quote! {
-                        let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
-                        match __discr.as_str()
-                    }
-                } else {
-                    quote! {
-                        let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
-                        match __discr
-                    }
-                };
+                arms.extend(quote! {
+                    #ident => { #code }
+                });
+            });
 
+            let type_str = typename.to_string();
+            let discr_match = if args.string_repr {
                 quote! {
-                    #consts
-
-                    #discr_match {
-                        #arms
-                        _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
-                    }
+                    let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
+                    match __discr.as_str()
                 }
-            }
-            syn::Data::Struct(s) => {
-                let (fields, fields_struct) = get_fields_struct(&s.fields);
-                let code = deserialize_fields(&fields);
-
+            } else {
                 quote! {
-                    #code
-                    Ok(Self #fields_struct)
+                    let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
+                    match __discr
+                }
+            };
+
+            quote! {
+                #consts
+
+                #discr_match {
+                    #arms
+                    _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
                 }
             }
-            _ => {
-                panic!("only enum and struct supported");
+        }
+        syn::Data::Struct(s) => {
+            let (fields, fields_struct) = get_fields_struct(&s.fields);
+            let code = deserialize_fields(&fields);
+
+            quote! {
+                #code
+                Ok(Self #fields_struct)
             }
         }
-    });
+        _ => {
+            panic!("only enum and struct supported");
+        }
+    };
+
+    deserialize_args(&args, &mut code);
 
     quote! {
                #[automatically_derived]
-               impl mt_ser::MtDeserialize for #typename {
+               impl #generics_bounded mt_ser::MtDeserialize for #typename #generics {
                        #[allow(non_upper_case_globals)]
                        fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
                                #code
index 961e4169b549c5042c78eed2d3a115ce080bd0da..b2aebc87e8b1e3bb634b6cdcdf5b51d460679bc0 100644 (file)
@@ -16,13 +16,15 @@ use std::{
     fmt::Debug,
     io::{self, Read, Write},
     num::TryFromIntError,
-    ops::Deref,
+    ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
 };
 use thiserror::Error;
 
 #[cfg(test)]
 mod tests;
 
+use crate as mt_ser;
+
 #[derive(Error, Debug)]
 pub enum SerializeError {
     #[error("io error: {0}")]
@@ -559,3 +561,56 @@ impl<T: MtDeserialize> MtDeserialize for Box<T> {
         Ok(Self::new(T::mt_deserialize::<C>(reader)?))
     }
 }
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "Range")]
+#[allow(unused)]
+struct RemoteRange<T> {
+    start: T,
+    end: T,
+}
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "RangeFrom")]
+#[allow(unused)]
+struct RemoteRangeFrom<T> {
+    start: T,
+}
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "RangeFull")]
+#[allow(unused)]
+struct RemoteRangeFull;
+
+// RangeInclusive fields are private
+impl<T: MtSerialize> MtSerialize for RangeInclusive<T> {
+    fn mt_serialize<C: MtCfg>(&self, writer: &mut impl Write) -> Result<(), SerializeError> {
+        self.start().mt_serialize::<DefCfg>(writer)?;
+        self.end().mt_serialize::<DefCfg>(writer)?;
+
+        Ok(())
+    }
+}
+
+impl<T: MtDeserialize> MtDeserialize for RangeInclusive<T> {
+    fn mt_deserialize<C: MtCfg>(reader: &mut impl Read) -> Result<Self, DeserializeError> {
+        let start = T::mt_deserialize::<DefCfg>(reader)?;
+        let end = T::mt_deserialize::<DefCfg>(reader)?;
+
+        Ok(start..=end)
+    }
+}
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "RangeTo")]
+#[allow(unused)]
+struct RemoteRangeTo<T> {
+    end: T,
+}
+
+#[derive(MtSerialize, MtDeserialize)]
+#[mt(typename = "RangeToInclusive")]
+#[allow(unused)]
+struct RemoteRangeToInclusive<T> {
+    end: T,
+}