1 use convert_case::{Case, Casing};
2 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
3 use proc_macro::TokenStream;
4 use proc_macro2::TokenStream as TokStr;
5 use quote::{quote, ToTokens};
6 use syn::{parse_macro_input, parse_quote};
/// Direction argument of `mt_derive` (read there as `args.to`).
///
/// `mt_derive` maps `To::Clt` to the ("server", "client") pair and `To::Srv`
/// to the ("client", "server") pair of (serializer, deserializer) cargo
/// feature names. darling matches variant names in snake_case.
// NOTE(review): the enum body (`Clt`, `Srv`) is not visible in this listing.
8 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
9 #[darling(rename_all = "snake_case")]
/// Arguments of the `#[mt_derive(...)]` attribute macro, parsed by darling
/// via `MacroArgs::from_list` in `mt_derive`.
// NOTE(review): `mt_derive` also reads `to`, `tag` and `custom`, which are
// declared on lines not shown in this listing.
15 #[derive(Debug, FromMeta)]
/// Discriminant type for enums. A `str` repr enables `mt(string_repr)`.
/// `mt_derive` panics with "missing repr for enum" when this is absent and
/// `custom` is not set.
18 repr: Option<syn::Type>,
/// serde `content` key, emitted as `serde(content = ...)` next to the `tag`.
20 content: Option<String>,
/// Rewrites a helper attribute in place so that it only takes effect when the
/// relevant cargo features are enabled.
///
/// One recognised attribute becomes
/// `#[cfg_attr(any(feature = "client", feature = "server"), <attr>)]`; another
/// becomes `#[cfg_attr(feature = "serde", <attr>)]`. Dispatch is on the
/// attribute path, and only when that path is a single identifier.
// NOTE(review): the match arm patterns are not visible in this listing;
// presumably `mt` and `serde` respectively. Confirm.
27 fn wrap_attr(attr: &mut syn::Attribute) {
// Copy the path and tokens into locals so they can be interpolated into the
// replacement attribute that overwrites `*attr`.
28 let path = attr.path.clone();
29 let tokens = attr.tokens.clone();
31 match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
33 *attr = parse_quote! {
34 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
38 *attr = parse_quote! {
39 #[cfg_attr(feature = "serde", #path #tokens)]
/// Attribute macro for protocol types (enums and structs).
///
/// It parses `MacroArgs` from the attribute arguments and picks the serializer
/// and deserializer cargo feature names from `to`. It then returns the input
/// item preceded by generated attributes:
/// - `GenerateRandom` behind feature "random";
/// - serde derives behind feature "serde";
/// - `MtSerialize` behind the serializer feature;
/// - `MtDeserialize` behind the deserializer feature;
/// - enum-specific repr, tag and content handling.
///
/// Existing helper attributes on the item, its variants and its fields are made
/// feature-conditional with `wrap_attr`.
///
/// # Panics
/// It panics in these cases:
/// - the item is neither an enum nor a struct;
/// - an enum has no `repr` and `custom` is not set;
/// - an enum has payload variants but no `tag`.
46 #[proc_macro_attribute]
47 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
48 let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
49 let mut input = parse_macro_input!(item as syn::Item);
// Malformed attribute arguments are reported as compile errors, not panics.
51 let args = match MacroArgs::from_list(&attr_args) {
54 return TokenStream::from(e.write_errors());
// `To::Clt`: the "server" feature serializes and the "client" feature
// deserializes. `To::Srv` is the reverse. These names gate the Mt derives below.
58 let (serializer, deserializer) = match args.to {
59 To::Clt => ("server", "client"),
60 To::Srv => ("client", "server"),
63 let mut out = quote! {
65 #[cfg_attr(feature = "random", derive(GenerateRandom))]
66 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// Applies `$f` to every element of `$t` by mutable reference.
70 ($t:expr, $f:expr) => {
71 $t.iter_mut().for_each($f)
// Enums: make helper attributes on the enum, its variants and their fields
// feature-conditional.
76 syn::Item::Enum(e) => {
77 iter!(e.attrs, wrap_attr);
78 iter!(e.variants, |v| {
79 iter!(v.attrs, wrap_attr);
80 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
85 #[derive(EnumSetType)]
86 #[enumset(serialize_as_map)]
// enumset's `repr` option takes the type as a string literal.
89 if let Some(repr) = args.repr {
90 let repr_str = repr.to_token_stream().to_string();
93 #[enumset(repr = #repr_str)]
95 } else if !args.custom {
96 panic!("missing repr for enum");
102 .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
// Enums whose variants carry fields need a serde `tag`, plus an optional `content`.
106 let tag = args.tag.expect("missing tag for enum with payload");
109 #[cfg_attr(feature = "serde", serde(tag = #tag))]
112 if let Some(content) = args.content {
114 #[cfg_attr(feature = "serde", serde(content = #content))]
124 #[derive(Clone, PartialEq)]
129 #[cfg_attr(feature = #serializer, derive(MtSerialize))]
130 #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
// A `str` repr switches the Mt derives to string discriminants (`mt(string_repr)`).
134 if let Some(repr) = args.repr {
135 if repr == parse_quote! { str } {
137 #[cfg_attr(any(feature = "client", feature = "server"), mt(string_repr))]
144 } else if !args.custom {
145 panic!("missing repr for enum");
150 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
// Structs: make helper attributes on the struct and its fields feature-conditional.
153 syn::Item::Struct(s) => {
154 iter!(s.attrs, wrap_attr);
155 iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
158 #[derive(Clone, PartialEq)]
163 #[cfg_attr(feature = #serializer, derive(MtSerialize))]
164 #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
168 _ => panic!("only enum and struct supported"),
// Emit the item, with its attributes rewritten, after the generated attributes.
171 out.extend(input.to_token_stream());
/// Options of the `#[mt(...)]` helper attribute.
///
/// darling reads these from derive inputs, enum variants and fields. Fields
/// declared on lines not shown in this listing are not documented here.
175 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
176 #[darling(attributes(mt))]
/// Values written around the item and checked on read. A mismatch on read
/// yields `DeserializeError::InvalidConst`.
180 const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
182 const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
/// Length-prefix type. On write, the item is serialized into a buffer that is
/// written with this config. On read, the prefix limits the bytes the item may
/// consume.
183 size: Option<syn::Type>, // must implement MtCfg
/// `MtCfg` passed as `C` when (de)serializing a field.
/// Defaults to `mt_ser::DefCfg`.
184 len: Option<syn::Type>, // must implement MtCfg
/// Passes the field's read result through `mt_ser::OrDefault::or_default`.
185 default: bool, // type must implement Default
/// Use snake_case variant names as discriminants; they are read back as a `String`.
186 string_repr: bool, // for enums
/// Fallible conversions applied to a field before serializing and after
/// deserializing. Their error types are `SerializeError` and
/// `DeserializeError` respectively.
189 map_ser: Option<syn::Expr>,
190 map_des: Option<syn::Expr>,
/// The field is multiplied by this before serializing and divided by it after
/// deserializing.
191 multiplier: Option<syn::Expr>,
192 typename: Option<syn::Ident>, // remote derive
/// Explicit where clause for the generated impl. It replaces the default
/// per-type-parameter bounds.
193 bounds: Option<syn::WhereClause>,
/// Fields paired with the tokens that refer to each one in generated code
/// (produced by `get_fields`).
196 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
/// Pairs every field with the tokens used to refer to it in generated code.
///
/// Named fields use their identifier and tuple fields use their positional
/// index. Both are passed through `ident`, e.g. `|f| quote! { &self.#f }`.
/// Unit fields produce an empty list.
198 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
// Named fields always carry an identifier, so this `unwrap` cannot fail.
200 syn::Fields::Named(fs) => fs
203 .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
205 syn::Fields::Unnamed(fs) => fs
209 .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
211 syn::Fields::Unit => Vec::new(),
/// Wraps the serialization code in `code` according to `args`.
///
/// This applies to type-, variant- or field-level `#[mt]` options. The
/// wrappers are zlib or zstd compression of the output stream, a `size` length
/// prefix, and the `const_before` / `const_after` values written around the
/// payload.
// NOTE(review): the statements that rebuild `*code` for each option are on
// lines not shown here. Going by the order of the visible branches,
// compression is presumably innermost and the constants outermost. Confirm.
215 fn serialize_args(args: &MtArgs, code: &mut TokStr) {
// Shadows `__writer` so the wrapped code writes into the encoder built by `$create`.
216 macro_rules! impl_compress {
220 let mut __stream = $create;
221 let __writer = &mut __stream;
// zlib via flate2 at its default compression level.
230 impl_compress!(mt_ser::flate2::write::ZlibEncoder::new(
232 mt_ser::flate2::Compression::default()
// zstd; level `0` selects zstd's default compression level.
237 impl_compress!(mt_ser::zstd::stream::write::Encoder::new(__writer, 0)?);
// Serialize the payload into a temporary buffer, then write that buffer using
// `size` as its `MtCfg`, i.e. as the length-prefix type.
240 if let Some(size) = &args.size {
242 mt_ser::MtSerialize::mt_serialize::<#size>(&{
243 let mut __buf = Vec::new();
244 let __writer = &mut __buf;
// `const_before` is walked in reverse, presumably because each constant is
// prepended to `*code`, which restores declaration order.
251 for x in args.const_before.iter().rev() {
253 #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
// Constants written after the payload.
258 for x in args.const_after.iter() {
261 #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
/// Reading counterpart of `serialize_args`.
///
/// It wraps the deserialization code in `code` with:
/// - zlib or zstd decompression;
/// - a `size` length prefix that limits the bytes available to the payload;
/// - checks of the `const_before` / `const_after` values around the payload.
// NOTE(review): the statements that rebuild `*code` are on lines not shown here.
266 fn deserialize_args(args: &MtArgs, code: &mut TokStr) {
// Rebinds `__reader` so the wrapped code reads from the decoder built by `$create`.
267 macro_rules! impl_compress {
271 let mut __owned_reader = $create;
272 let __reader = &mut __owned_reader;
// The current reader is wrapped in `mt_ser::WrapRead` before it is handed to
// the decoder.
281 impl_compress!(mt_ser::flate2::read::ZlibDecoder::new(mt_ser::WrapRead(
287 impl_compress!(mt_ser::zstd::stream::read::Decoder::new(mt_ser::WrapRead(
// Read the length prefix as a `size` value, then expose only that many bytes
// to the payload via `std::io::Read::take`.
// NOTE(review): `#size::mt_deserialize` only parses when `size` is a path
// without generic arguments (e.g. `u32`). `<#size as mt_ser::MtDeserialize>::`
// would accept any type. This code, and `impl_const` below, also name
// `DefCfg` unqualified, whereas `serialize_args` uses `mt_ser::DefCfg`.
// Confirm `DefCfg` is always in scope where this expands.
292 if let Some(size) = &args.size {
295 let __size = #size::mt_deserialize::<DefCfg>(__reader)? as u64;
296 let mut __owned_reader = std::io::Read::take(
297 mt_ser::WrapRead(__reader),
300 let __reader = &mut __owned_reader;
// Generates code that reads one value and fails with
// `DeserializeError::InvalidConst` unless the value equals `want`. The
// generated `eq_same_type` helper forces the read value to have `want`'s type,
// which is what drives type inference for `mt_deserialize`.
307 let impl_const = |want: &syn::Expr| {
310 fn eq_same_type<T: PartialEq<T>>(a: &T, b: &T) -> bool {
315 let got = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
317 if !eq_same_type(&want, &got) {
318 return Err(mt_ser::DeserializeError::InvalidConst(
319 Box::new(want), Box::new(got)
// Walked in reverse, as in `serialize_args`.
326 for want in args.const_before.iter().rev() {
327 let imp = impl_const(want);
// Trailing constants: the payload's result is kept in `__result` while they
// are checked (presumably returned afterwards, on lines not shown).
336 for want in args.const_after.iter() {
337 let imp = impl_const(want);
340 let __result = #code;
/// Generates serialization statements for `fields`, in order.
///
/// Each field expression goes through these steps:
/// 1. optionally multiplied by `multiplier`;
/// 2. optionally passed through `map_ser`;
/// 3. serialized with `len` as its `MtCfg` (default `mt_ser::DefCfg`);
/// 4. wrapped by the field's own `serialize_args`.
///
/// # Panics
/// If a field's `#[mt(...)]` attribute fails to parse (the darling result is unwrapped).
348 fn serialize_fields(fields: &Fields) -> TokStr {
351 .map(|(ident, field)| {
352 let args = MtArgs::from_field(field).unwrap();
// Fields without `len` use the default config.
353 let def = parse_quote! { mt_ser::DefCfg };
354 let len = args.len.as_ref().unwrap_or(&def);
356 let mut code = quote! { #ident };
358 if let Some(multiplier) = &args.multiplier {
360 &((#code) * (#multiplier))
// The generated `call_ser_result` pins the error type to `SerializeError`, so
// `#map` can be a plain closure whose input/output types are inferred.
364 if let Some(map) = &args.map_ser {
367 fn call_ser_result<I, O>(
368 f: impl FnOnce(I) -> Result<O, mt_ser::SerializeError>,
370 ) -> Result<O, mt_ser::SerializeError> {
374 &call_ser_result(#map, #code)?
379 code = quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#code, __writer)?; };
381 serialize_args(&args, &mut code);
/// Generates deserialization code for `fields`, in order.
///
/// This mirrors `serialize_fields` in reverse. Each field goes through these steps:
/// 1. read with `len` as its `MtCfg` (default `mt_ser::DefCfg`);
/// 2. optionally passed through `OrDefault::or_default`;
/// 3. wrapped by the field's `deserialize_args`;
/// 4. passed through `map_des`;
/// 5. divided by `multiplier`.
///
/// # Panics
/// If a field's `#[mt(...)]` attribute fails to parse (the darling result is unwrapped).
388 fn deserialize_fields(fields: &Fields) -> TokStr {
391 .map(|(ident, field)| {
392 let args = MtArgs::from_field(field).unwrap();
// Fields without `len` use the default config. With `default`, the read result
// is passed through `mt_ser::OrDefault::or_default`.
394 let def = parse_quote! { mt_ser::DefCfg };
395 let len = args.len.as_ref().unwrap_or(&def);
396 let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
400 mt_ser::OrDefault::or_default(#code)
408 deserialize_args(&args, &mut code);
// The generated `call_des_result` pins the error type to `DeserializeError`,
// so `#map` can be a plain closure with inferred input/output types.
410 if let Some(map) = &args.map_des {
413 fn call_des_result<I, O>(
414 f: impl FnOnce(I) -> Result<O, mt_ser::DeserializeError>,
416 ) -> Result<O, mt_ser::DeserializeError> {
420 call_des_result(#map, #code)?
// The generated `div_same_type` constrains the quotient to the dividend's
// type, i.e. the field's type.
425 if let Some(multiplier) = &args.multiplier {
428 fn div_same_type<D, T: std::ops::Div<D, Output = T>>(a: T, b: D) -> T {
432 div_same_type(#code, #multiplier)
/// Builds binding names for a struct's or variant's fields, together with the
/// matching field-list suffix.
///
/// The suffix is `{ a, b, }` for named fields, `( .., .., )` for tuple fields,
/// and nothing for unit. It is used both as a pattern (after
/// `Type::Variant` in `derive_serialize`) and as a constructor (after
/// `Self` / `Self::Variant` in `derive_deserialize`).
444 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
// Tuple fields have no identifier to bind, so they are named `field_<index>`
// by `mt_ser::paste::paste!` at the expansion site. Named fields keep their
// own identifier.
445 let ident_fn = match input {
446 syn::Fields::Unnamed(_) => |f| {
448 mt_ser::paste::paste! { [<field_ #f>] }
451 _ => |f| quote! { #f },
454 let fields = get_fields(input, ident_fn);
// Comma-separated binding list with a trailing comma, built back to front.
455 let fields_comma: TokStr = fields
457 .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
459 let fields_struct = match input {
460 syn::Fields::Named(_) => quote! { { #fields_comma } },
461 syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
462 syn::Fields::Unit => TokStr::new(),
465 (fields, fields_struct)
/// Returns the discriminant type of an enum.
///
/// This is `&str` when `string_repr` is set. Otherwise it is the type named in
/// the enum's `#[repr(...)]` attribute.
///
/// # Panics
/// - "missing repr" if there is no `#[repr]` attribute.
/// - "invalid repr" if its argument cannot be parsed. The parsing call is on a
///   line not shown in this listing.
468 fn get_repr(input: &syn::DeriveInput, args: &MtArgs) -> syn::Type {
469 if args.string_repr {
470 parse_quote! { &str }
475 .find(|a| a.path.is_ident("repr"))
476 .expect("missing repr")
478 .expect("invalid repr")
/// Visits the enum's variants in declaration order. `f` receives each variant
/// together with its discriminant expression.
///
/// With `string_repr`, the discriminant is a string literal holding the
/// variant name in snake_case. Otherwise it follows Rust's rule:
/// - an explicit `= expr` is used as-is;
/// - each implicit discriminant is `1 + <previous>`, starting from `0`.
///
/// The sum is left symbolic for the compiler to evaluate.
// NOTE(review): the call to `f` is on a line not shown in this listing.
482 fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Variant, &syn::Expr)) {
483 let mut discr = parse_quote! { 0 };
485 for v in e.variants.iter() {
486 discr = if args.string_repr {
487 let lit = v.ident.to_string().to_case(Case::Snake);
488 parse_quote! { #lit }
490 v.discriminant.clone().map(|x| x.1).unwrap_or(discr)
495 discr = parse_quote! { 1 + #discr };
// NOTE(review): the `fn` header of this trait-impl helper is on lines not
// shown in this listing. So are its other inputs (`traitname`, `args`,
// `code`), which are presumably parameters.
//
// Emits `#[automatically_derived] impl<..> <traitname> for <typename><..> <where>`
// around `code`. The where clause is `args.bounds` when given. Otherwise, for
// generic types, each type parameter gets a `<traitname>` bound. Other
// parameter kinds are handled on lines not shown.
//
// NOTE(review): `#generics` interpolates the parameters together with their
// inline bounds and defaults (e.g. `<T: Clone>`). That is not valid in the
// `for Type<..>` position, and defaults are not allowed on impl parameters.
// It also omits the input's own where clause. `input.generics.split_for_impl()`
// provides the impl, type and where parts separately.
501 input: &syn::DeriveInput,
502 typename: &syn::Ident,
506 let generics = &input.generics;
507 let bounds = args.bounds.clone().or_else(|| {
508 if generics.params.is_empty() {
516 .rfold(quote! { where }, |before, t| match t {
517 syn::GenericParam::Type(x) => quote! { #before #x: #traitname, },
522 .expect("invalid where clause"),
528 #[automatically_derived]
529 impl #generics #traitname for #typename #generics #bounds { #code }
/// Derive macro for `MtSerialize`.
///
/// Enums: each variant's match arm binds its fields by pattern. It writes the
/// discriminant, cast to the repr type and serialized with `mt_ser::DefCfg`,
/// followed by the variant's field code. Variant-level `#[mt]` options wrap
/// that variant's code.
///
/// Structs: fields are serialized in declaration order via `&self.<field>`.
///
/// Type-level `#[mt]` options wrap the whole body. With `typename` (remote
/// derive), that identifier is used instead of the input's own name.
///
/// # Panics
/// On unions and on malformed `#[mt]` attributes.
534 #[proc_macro_derive(MtSerialize, attributes(mt))]
535 pub fn derive_serialize(input: TokenStream) -> TokenStream {
536 let input = parse_macro_input!(input as syn::DeriveInput);
537 let args = MtArgs::from_derive_input(&input).unwrap();
538 let typename = args.typename.as_ref().unwrap_or(&input.ident);
540 let mut code = match &input.data {
541 syn::Data::Enum(e) => {
542 let repr = get_repr(&input, &args);
// One match arm per variant is collected here.
543 let mut variants = TokStr::new();
545 iter_variants(e, &args, |v, discr| {
546 let args = MtArgs::from_variant(v).unwrap();
548 let (fields, fields_struct) = get_fields_struct(&v.fields);
550 let mut code = serialize_fields(&fields);
551 serialize_args(&args, &mut code);
553 let ident = &v.ident;
555 variants.extend(quote! {
556 #typename::#ident #fields_struct => {
557 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
// Structs: fields are accessed by reference through `self`.
569 syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
571 panic!("only enum and struct supported");
// Type-level options (compression, size prefix, constants) wrap the whole body.
575 serialize_args(&args, &mut code);
578 quote! { mt_ser::MtSerialize },
583 fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
/// Derive macro for `MtDeserialize`.
///
/// Enums: each variant's discriminant becomes a local
/// `const <Variant>: <repr> = <discr>;`, presumably the reason for
/// `#[allow(non_upper_case_globals)]` on the generated fn. The value read from
/// the input is then matched per variant:
/// - with `string_repr`, a `String` is read and matched via `as_str()`;
/// - otherwise, a repr value is read with `DefCfg`.
///
/// Unknown values produce `DeserializeError::InvalidEnum(<type name>, value)`.
///
/// Structs: fields are read in declaration order and assembled with `Ok(Self ...)`.
///
/// # Panics
/// On unions and on malformed `#[mt]` attributes.
592 #[proc_macro_derive(MtDeserialize, attributes(mt))]
593 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
594 let input = parse_macro_input!(input as syn::DeriveInput);
595 let args = MtArgs::from_derive_input(&input).unwrap();
596 let typename = args.typename.as_ref().unwrap_or(&input.ident);
598 let mut code = match &input.data {
599 syn::Data::Enum(e) => {
600 let repr = get_repr(&input, &args);
// One discriminant `const` and one match arm per variant.
602 let mut consts = TokStr::new();
603 let mut arms = TokStr::new();
605 iter_variants(e, &args, |v, discr| {
606 let args = MtArgs::from_variant(v).unwrap();
608 let ident = &v.ident;
609 let (fields, fields_struct) = get_fields_struct(&v.fields);
610 let mut code = deserialize_fields(&fields);
613 Ok(Self::#ident #fields_struct)
616 deserialize_args(&args, &mut code);
618 consts.extend(quote! {
619 const #ident: #repr = #discr;
// Type name reported in `InvalidEnum` errors.
627 let type_str = typename.to_string();
// NOTE(review): these generated reads name `DefCfg` unqualified, and
// `String::mt_deserialize` needs the `MtDeserialize` trait in scope. The
// imports that provide them are not visible here. Confirm.
628 let discr_match = if args.string_repr {
630 let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
631 match __discr.as_str()
635 let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
645 _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
649 syn::Data::Struct(s) => {
650 let (fields, fields_struct) = get_fields_struct(&s.fields);
651 let code = deserialize_fields(&fields);
655 Ok(Self #fields_struct)
659 panic!("only enum and struct supported");
// Type-level options (decompression, size limit, constants) wrap the whole body.
663 deserialize_args(&args, &mut code);
666 quote! { mt_ser::MtDeserialize },
671 #[allow(non_upper_case_globals)]
672 fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {