]> git.lizzy.rs Git - mt_ser.git/commitdiff
Allow bounds clause for remote derive
authorLizzy Fleckenstein <eliasfleckenstein@web.de>
Sun, 26 Feb 2023 20:27:16 +0000 (21:27 +0100)
committerLizzy Fleckenstein <eliasfleckenstein@web.de>
Sun, 26 Feb 2023 20:27:16 +0000 (21:27 +0100)
derive/src/lib.rs

index 27af31b3f9457b29d39e9a3b9ec928da74028fb6..0ccc8cacb9b14ff4b97d853ccac577b672b31700 100644 (file)
@@ -187,6 +187,7 @@ struct MtArgs {
     zlib: bool,
     zstd: bool,
     typename: Option<syn::Ident>, // remote derive
+    bounds: Option<syn::WhereClause>,
 }
 
 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
@@ -438,19 +439,46 @@ fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Varia
     }
 }
 
+fn make_impl(
+    traitname: TokStr,
+    input: &syn::DeriveInput,
+    typename: &syn::Ident,
+    args: &MtArgs,
+    code: TokStr,
+) -> TokenStream {
+    let generics = &input.generics;
+    let bounds = args.bounds.clone().or_else(|| {
+        if generics.params.is_empty() {
+            None
+        } else {
+            Some(
+                syn::parse(
+                    generics
+                        .params
+                        .iter()
+                        .rfold(quote! { where }, |before, t| match t {
+                            syn::GenericParam::Type(x) => quote! { #before #x: #traitname, },
+                            _ => before,
+                        })
+                        .into(),
+                )
+                .expect("invalid where clause"),
+            )
+        }
+    });
+
+    quote! {
+        #[automatically_derived]
+        impl #generics #traitname for #typename #generics #bounds { #code }
+    }
+    .into()
+}
+
 #[proc_macro_derive(MtSerialize, attributes(mt))]
 pub fn derive_serialize(input: TokenStream) -> TokenStream {
     let input = parse_macro_input!(input as syn::DeriveInput);
     let args = MtArgs::from_derive_input(&input).unwrap();
     let typename = args.typename.as_ref().unwrap_or(&input.ident);
-    let generics = &input.generics;
-    let mut generics_bounded = generics.clone();
-    for t in generics_bounded.params.iter_mut() {
-        match t {
-            syn::GenericParam::Type(x) => *t = parse_quote! { #x: mt_ser::MtSerialize },
-            _ => {}
-        }
-    }
 
     let mut code = match &input.data {
         syn::Data::Enum(e) => {
@@ -489,16 +517,19 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
 
     serialize_args(&args, &mut code);
 
-    quote! {
-               #[automatically_derived]
-               impl #generics_bounded mt_ser::MtSerialize for #typename #generics {
-                       fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
-                               #code
-
-                               Ok(())
-                       }
-               }
-       }.into()
+    make_impl(
+        quote! { mt_ser::MtSerialize },
+        &input,
+        typename,
+        &args,
+        quote! {
+            fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
+                #code
+
+                Ok(())
+            }
+        },
+    )
 }
 
 #[proc_macro_derive(MtDeserialize, attributes(mt))]
@@ -506,14 +537,6 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
     let input = parse_macro_input!(input as syn::DeriveInput);
     let args = MtArgs::from_derive_input(&input).unwrap();
     let typename = args.typename.as_ref().unwrap_or(&input.ident);
-    let generics = &input.generics;
-    let mut generics_bounded = generics.clone();
-    for t in generics_bounded.params.iter_mut() {
-        match t {
-            syn::GenericParam::Type(x) => *t = parse_quote! { #x: mt_ser::MtDeserialize },
-            _ => {}
-        }
-    }
 
     let mut code = match &input.data {
         syn::Data::Enum(e) => {
@@ -582,13 +605,16 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
 
     deserialize_args(&args, &mut code);
 
-    quote! {
-               #[automatically_derived]
-               impl #generics_bounded mt_ser::MtDeserialize for #typename #generics {
-                       #[allow(non_upper_case_globals)]
-                       fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
-                               #code
-                       }
-               }
-       }.into()
+    make_impl(
+        quote! { mt_ser::MtDeserialize },
+        &input,
+        typename,
+        &args,
+        quote! {
+            #[allow(non_upper_case_globals)]
+            fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
+                #code
+            }
+        },
+    )
 }