]> git.lizzy.rs Git - mt_ser.git/blobdiff - derive/src/lib.rs
Support string repr
[mt_ser.git] / derive / src / lib.rs
index 82d964421b8628e624ebc18d2ed49ce2b86c8929..ae9842c2e249ce711d6fa68d420d1374cea94b4c 100644 (file)
@@ -1,3 +1,4 @@
+use convert_case::{Case, Casing};
 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
 use proc_macro::TokenStream;
 use proc_macro2::TokenStream as TokStr;
@@ -117,14 +118,6 @@ pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
                     });
                 }
 
-                if let Some(repr) = args.repr {
-                    out.extend(quote! {
-                        #[repr(#repr)]
-                    });
-                } else if !args.custom {
-                    panic!("missing repr for enum");
-                }
-
                 out.extend(quote! {
                     #[derive(Clone, PartialEq)]
                 });
@@ -135,6 +128,20 @@ pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
                     });
                 }
+
+                if let Some(repr) = args.repr {
+                    if repr == parse_quote! { str } {
+                        out.extend(quote! {
+                            #[mt(string_repr)]
+                        });
+                    } else {
+                        out.extend(quote! {
+                            #[repr(#repr)]
+                        });
+                    }
+                } else if !args.custom {
+                    panic!("missing repr for enum");
+                }
             }
 
             out.extend(quote! {
@@ -180,6 +187,8 @@ struct MtArgs {
     zlib: bool,
     zstd: bool,    // TODO
     default: bool, // type must implement Default
+
+    string_repr: bool, // for enums
 }
 
 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
@@ -392,14 +401,35 @@ fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
     (fields, fields_struct)
 }
 
-fn get_repr(input: &syn::DeriveInput) -> syn::Type {
-    input
-        .attrs
-        .iter()
-        .find(|a| a.path.is_ident("repr"))
-        .expect("missing repr")
-        .parse_args()
-        .expect("invalid repr")
+fn get_repr(input: &syn::DeriveInput, args: &MtArgs) -> syn::Type {
+    if args.string_repr {
+        parse_quote! { &str }
+    } else {
+        input
+            .attrs
+            .iter()
+            .find(|a| a.path.is_ident("repr"))
+            .expect("missing repr")
+            .parse_args()
+            .expect("invalid repr")
+    }
+}
+
+fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Variant, &syn::Expr)) {
+    let mut discr = parse_quote! { 0 };
+
+    for v in e.variants.iter() {
+        discr = if args.string_repr {
+            let lit = v.ident.to_string().to_case(Case::Snake);
+            parse_quote! { #lit }
+        } else {
+            v.discriminant.clone().map(|x| x.1).unwrap_or(discr)
+        };
+
+        f(&v, &discr);
+
+        discr = parse_quote! { 1 + #discr };
+    }
 }
 
 #[proc_macro_derive(MtSerialize, attributes(mt))]
@@ -407,40 +437,38 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
     let input = parse_macro_input!(input as syn::DeriveInput);
     let typename = &input.ident;
 
-    let code = serialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
-        syn::Data::Enum(e) => {
-            let repr = get_repr(&input);
-            let variants: TokStr = e.variants
-                               .iter()
-                               .fold((parse_quote! { 0 }, TokStr::new()), |(discr, before), v| {
-                                       let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
-                                       let (fields, fields_struct) = get_fields_struct(&v.fields);
-
-                                       let code = serialize_args(MtArgs::from_variant(v), |_|
-                                               serialize_fields(&fields));
-                                       let variant = &v.ident;
-
-                                       (
-                                               parse_quote! { 1 + #discr },
-                                               quote! {
-                                                       #before
-                                                       #typename::#variant #fields_struct => {
-                                                               mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
-                                                               #code
-                                                       }
-                                               }
-                                       )
-                               }).1;
+    let code = serialize_args(MtArgs::from_derive_input(&input), |args| {
+        match &input.data {
+            syn::Data::Enum(e) => {
+                let repr = get_repr(&input, &args);
+                let mut variants = TokStr::new();
+
+                iter_variants(&e, &args, |v, discr| {
+                    let (fields, fields_struct) = get_fields_struct(&v.fields);
+                    let code =
+                        serialize_args(MtArgs::from_variant(v), |_| serialize_fields(&fields));
+                    let ident = &v.ident;
+
+                    variants.extend(quote! {
+                                       #typename::#ident #fields_struct => {
+                                               mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
+                                               #code
+                                       }
+                               });
+                });
 
-            quote! {
-                match self {
-                    #variants
+                quote! {
+                    match self {
+                        #variants
+                    }
                 }
             }
-        }
-        syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
-        _ => {
-            panic!("only enum and struct supported");
+            syn::Data::Struct(s) => {
+                serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f }))
+            }
+            _ => {
+                panic!("only enum and struct supported");
+            }
         }
     });
 
@@ -461,60 +489,69 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
     let input = parse_macro_input!(input as syn::DeriveInput);
     let typename = &input.ident;
 
-    let code = deserialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
-        syn::Data::Enum(e) => {
-            let repr = get_repr(&input);
-            let type_str = typename.to_string();
+    let code = deserialize_args(MtArgs::from_derive_input(&input), |args| {
+        match &input.data {
+            syn::Data::Enum(e) => {
+                let repr = get_repr(&input, &args);
 
-            let mut consts = TokStr::new();
-            let mut arms = TokStr::new();
-            let mut discr = parse_quote! { 0 };
+                let mut consts = TokStr::new();
+                let mut arms = TokStr::new();
 
-            for v in e.variants.iter() {
-                discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
+                iter_variants(&e, &args, |v, discr| {
+                    let ident = &v.ident;
+                    let (fields, fields_struct) = get_fields_struct(&v.fields);
+                    let code = deserialize_args(MtArgs::from_variant(v), |_| {
+                        let fields_code = deserialize_fields(&fields);
 
-                let ident = &v.ident;
-                let (fields, fields_struct) = get_fields_struct(&v.fields);
-                let code = deserialize_args(MtArgs::from_variant(v), |_| {
-                    let fields_code = deserialize_fields(&fields);
-
-                    quote! {
-                        #fields_code
-                        Ok(Self::#ident #fields_struct)
-                    }
-                });
+                        quote! {
+                            #fields_code
+                            Ok(Self::#ident #fields_struct)
+                        }
+                    });
 
-                consts.extend(quote! {
-                    const #ident: #repr = #discr;
-                });
+                    consts.extend(quote! {
+                        const #ident: #repr = #discr;
+                    });
 
-                arms.extend(quote! {
-                    #ident => { #code }
+                    arms.extend(quote! {
+                        #ident => { #code }
+                    });
                 });
 
-                discr = parse_quote! { 1 + #discr };
-            }
+                let type_str = typename.to_string();
+                let discr_match = if args.string_repr {
+                    quote! {
+                        let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
+                        match __discr.as_str()
+                    }
+                } else {
+                    quote! {
+                        let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
+                        match __discr
+                    }
+                };
 
-            quote! {
-                #consts
+                quote! {
+                    #consts
 
-                match mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)? {
-                    #arms
-                    x => Err(mt_ser::DeserializeError::InvalidEnumVariant(#type_str, x as u64))
+                    #discr_match {
+                        #arms
+                        _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
+                    }
                 }
             }
-        }
-        syn::Data::Struct(s) => {
-            let (fields, fields_struct) = get_fields_struct(&s.fields);
-            let code = deserialize_fields(&fields);
+            syn::Data::Struct(s) => {
+                let (fields, fields_struct) = get_fields_struct(&s.fields);
+                let code = deserialize_fields(&fields);
 
-            quote! {
-                #code
-                Ok(Self #fields_struct)
+                quote! {
+                    #code
+                    Ok(Self #fields_struct)
+                }
+            }
+            _ => {
+                panic!("only enum and struct supported");
             }
-        }
-        _ => {
-            panic!("only enum and struct supported");
         }
     });