]> git.lizzy.rs Git - mt_ser.git/blobdiff - derive/src/lib.rs
Implement zstd compression
[mt_ser.git] / derive / src / lib.rs
index 5b80762ae6134fc468353b43ea58327ff6129c11..d4d89913c4529558884fd57ad51124e16b62fbc7 100644 (file)
@@ -174,19 +174,14 @@ pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
 struct MtArgs {
     #[darling(multiple)]
     const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
-
     #[darling(multiple)]
     const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
-
     size: Option<syn::Type>, // must implement MtCfg
-
-    len: Option<syn::Type>, // must implement MtCfg
-
+    len: Option<syn::Type>,  // must implement MtCfg
+    default: bool,           // type must implement Default
+    string_repr: bool,       // for enums
     zlib: bool,
-    zstd: bool,    // TODO
-    default: bool, // type must implement Default
-
-    string_repr: bool, // for enums
+    zstd: bool,
 }
 
 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
@@ -213,20 +208,30 @@ fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> To
         Ok(args) => {
             let mut code = body(&args);
 
-            if args.zlib {
-                code = quote! {
-                    let mut __writer = {
-                        let mut __stream = mt_ser::flate2::write::ZlibEncoder::new(
-                            __writer,
-                            mt_ser::flate2::Compression::default(),
-                        );
-                        let __writer = &mut __stream;
-                        #code
-                        __stream.finish()?
+            macro_rules! impl_compress {
+                ($create:expr) => {
+                    code = quote! {
+                        let mut __writer = {
+                            let mut __stream = $create;
+                            let __writer = &mut __stream;
+                            #code
+                            __stream.finish()?
+                        };
                     };
                 };
             }
 
+            if args.zlib {
+                impl_compress!(mt_ser::flate2::write::ZlibEncoder::new(
+                    __writer,
+                    mt_ser::flate2::Compression::default()
+                ));
+            }
+
+            if args.zstd {
+                impl_compress!(mt_ser::zstd::stream::write::Encoder::new(__writer, 0)?);
+            }
+
             if let Some(size) = args.size {
                 code = quote! {
                     mt_ser::MtSerialize::mt_serialize::<#size>(&{
@@ -263,16 +268,29 @@ fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) ->
         Ok(args) => {
             let mut code = body(&args);
 
-            if args.zlib {
-                code = quote! {
-                    {
-                        let mut __owned_reader = mt_ser::flate2::read::ZlibDecoder::new(
-                            mt_ser::WrapRead(__reader));
-                        let __reader = &mut __owned_reader;
+            macro_rules! impl_compress {
+                ($create:expr) => {
+                    code = quote! {
+                        {
+                            let mut __owned_reader = $create;
+                            let __reader = &mut __owned_reader;
 
-                        #code
+                            #code
+                        }
                     }
-                }
+                };
+            }
+
+            if args.zlib {
+                impl_compress!(mt_ser::flate2::read::ZlibDecoder::new(mt_ser::WrapRead(
+                    __reader
+                )));
+            }
+
+            if args.zstd {
+                impl_compress!(mt_ser::zstd::stream::read::Decoder::new(mt_ser::WrapRead(
+                    __reader
+                ))?);
             }
 
             if let Some(size) = args.size {