]> git.lizzy.rs Git - mt_ser.git/commitdiff
Add multiplier and maps master
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Mon, 27 Feb 2023 20:43:56 +0000 (21:43 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Mon, 27 Feb 2023 20:43:56 +0000 (21:43 +0100)
derive/src/lib.rs

index 0ccc8cacb9b14ff4b97d853ccac577b672b31700..9ad2c7a9bccbc5e8d99317765c28c056bc2af7ac 100644 (file)
@@ -186,6 +186,9 @@ struct MtArgs {
     string_repr: bool,       // for enums
     zlib: bool,
     zstd: bool,
+    map_ser: Option<syn::Expr>,
+    map_des: Option<syn::Expr>,
+    multiplier: Option<syn::Expr>,
     typename: Option<syn::Ident>, // remote derive
     bounds: Option<syn::WhereClause>,
 }
@@ -288,56 +291,55 @@ fn deserialize_args(args: &MtArgs, code: &mut TokStr) {
 
     if let Some(size) = &args.size {
         *code = quote! {
-            #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
+            {
+                let __size = #size::mt_deserialize::<DefCfg>(__reader)? as u64;
                 let mut __owned_reader = std::io::Read::take(
-                    mt_ser::WrapRead(__reader), size as u64);
+                    mt_ser::WrapRead(__reader),
+                    __size,
+                );
                 let __reader = &mut __owned_reader;
 
                 #code
-            })
+            }
         };
     }
 
-    let impl_const = |value: &TokStr| {
+    let impl_const = |want: &syn::Expr| {
         quote! {
             {
-                fn deserialize_same_type<T: MtDeserialize>(
-                    _: &T,
-                    reader: &mut impl std::io::Read
-                ) -> Result<T, mt_ser::DeserializeError> {
-                    T::mt_deserialize::<mt_ser::DefCfg>(reader)
+                fn eq_same_type<T: PartialEq<T>>(a: &T, b: &T) -> bool {
+                    a == b
                 }
 
-                deserialize_same_type(&want, __reader)
-                    .and_then(|got| {
-                        if want == got {
-                            #value
-                        } else {
-                            Err(mt_ser::DeserializeError::InvalidConst(
-                                Box::new(want), Box::new(got)
-                            ))
-                        }
-                    })
+                let want = #want;
+                let got = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
+
+                if !eq_same_type(&want, &got) {
+                    return Err(mt_ser::DeserializeError::InvalidConst(
+                        Box::new(want), Box::new(got)
+                    ));
+                }
             }
         }
     };
 
     for want in args.const_before.iter().rev() {
-        let imp = impl_const(&code);
+        let imp = impl_const(want);
         *code = quote! {
             {
-                let want = #want;
                 #imp
+                #code
             }
         };
     }
 
     for want in args.const_after.iter() {
-        let imp = impl_const(&quote! { Ok(value) });
+        let imp = impl_const(want);
         *code = quote! {
             {
-                let want = #want;
-                #code.and_then(|value| { #imp })
+                let __result = #code;
+                #imp
+                __result
             }
         };
     }
@@ -351,7 +353,31 @@ fn serialize_fields(fields: &Fields) -> TokStr {
             let def = parse_quote! { mt_ser::DefCfg };
             let len = args.len.as_ref().unwrap_or(&def);
 
-            let mut code = quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; };
+            let mut code = quote! { #ident };
+
+            if let Some(multiplier) = &args.multiplier {
+                code = quote! {
+                    &((#code) * (#multiplier))
+                };
+            }
+
+            if let Some(map) = &args.map_ser {
+                code = quote! {
+                    {
+                        fn call_ser_result<I, O>(
+                            f: impl FnOnce(I) -> Result<O, mt_ser::SerializeError>,
+                            i: I
+                        ) -> Result<O, mt_ser::SerializeError> {
+                            f(i)
+                        }
+
+                        &call_ser_result(#map, #code)?
+                    }
+                };
+            }
+
+            code = quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#code, __writer)?; };
+
             serialize_args(&args, &mut code);
 
             code
@@ -375,10 +401,41 @@ fn deserialize_fields(fields: &Fields) -> TokStr {
                 };
             }
 
+            code = quote! {
+                (#code)?
+            };
+
             deserialize_args(&args, &mut code);
 
+            if let Some(map) = &args.map_des {
+                code = quote! {
+                    {
+                        fn call_des_result<I, O>(
+                            f: impl FnOnce(I) -> Result<O, mt_ser::DeserializeError>,
+                            i: I
+                        ) -> Result<O, mt_ser::DeserializeError> {
+                            f(i)
+                        }
+
+                        call_des_result(#map, #code)?
+                    }
+                };
+            }
+
+            if let Some(multiplier) = &args.multiplier {
+                code = quote! {
+                    {
+                        fn div_same_type<D, T: std::ops::Div<D, Output = T>>(a: T, b: D) -> T {
+                            a / b
+                        }
+
+                        div_same_type(#code, #multiplier)
+                    }
+                }
+            }
+
             quote! {
-                let #ident = #code?;
+                let #ident = #code;
             }
         })
         .collect()