1 use convert_case::{Case, Casing};
2 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
3 use proc_macro::TokenStream;
4 use proc_macro2::TokenStream as TokStr;
5 use quote::{quote, ToTokens};
6 use syn::{parse_macro_input, parse_quote};
8 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
9 #[darling(rename_all = "snake_case")]
15 #[derive(Debug, FromMeta)]
18 repr: Option<syn::Type>,
20 content: Option<String>,
27 fn wrap_attr(attr: &mut syn::Attribute) {
28 let path = attr.path.clone();
29 let tokens = attr.tokens.clone();
31 match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
33 *attr = parse_quote! {
34 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
38 *attr = parse_quote! {
39 #[cfg_attr(feature = "serde", #path #tokens)]
46 #[proc_macro_attribute]
47 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
48 let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
49 let mut input = parse_macro_input!(item as syn::Item);
51 let args = match MacroArgs::from_list(&attr_args) {
54 return TokenStream::from(e.write_errors());
58 let (serializer, deserializer) = match args.to {
59 To::Clt => ("server", "client"),
60 To::Srv => ("client", "server"),
63 let mut out = quote! {
65 #[cfg_attr(feature = "random", derive(GenerateRandom))]
66 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
70 ($t:expr, $f:expr) => {
71 $t.iter_mut().for_each($f)
76 syn::Item::Enum(e) => {
77 iter!(e.attrs, wrap_attr);
78 iter!(e.variants, |v| {
79 iter!(v.attrs, wrap_attr);
80 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
86 .expect("missing repr for enum")
91 #[derive(EnumSetType)]
92 #[enumset(repr = #repr_str, serialize_as_map)]
98 .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
102 let tag = args.tag.expect("missing tag for enum with payload");
105 #[cfg_attr(feature = "serde", serde(tag = #tag))]
108 if let Some(content) = args.content {
110 #[cfg_attr(feature = "serde", serde(content = #content))]
120 #[derive(Clone, PartialEq)]
125 #[cfg_attr(feature = #serializer, derive(MtSerialize))]
126 #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
130 if let Some(repr) = args.repr {
131 if repr == parse_quote! { str } {
133 #[cfg_attr(any(feature = "client", feature = "server"), mt(string_repr))]
140 } else if !args.custom {
141 panic!("missing repr for enum");
146 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
149 syn::Item::Struct(s) => {
150 iter!(s.attrs, wrap_attr);
151 iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
154 #[derive(Clone, PartialEq)]
159 #[cfg_attr(feature = #serializer, derive(MtSerialize))]
160 #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
164 _ => panic!("only enum and struct supported"),
167 out.extend(input.to_token_stream());
171 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
172 #[darling(attributes(mt))]
176 const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
179 const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
181 size: Option<syn::Type>, // must implement MtCfg
183 len: Option<syn::Type>, // must implement MtCfg
187 default: bool, // type must implement Default
189 string_repr: bool, // for enums
192 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
194 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
196 syn::Fields::Named(fs) => fs
199 .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
201 syn::Fields::Unnamed(fs) => fs
205 .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
207 syn::Fields::Unit => Vec::new(),
211 fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
214 let mut code = body(&args);
219 let mut __stream = mt_ser::flate2::write::ZlibEncoder::new(
221 mt_ser::flate2::Compression::default(),
223 let __writer = &mut __stream;
230 if let Some(size) = args.size {
232 mt_ser::MtSerialize::mt_serialize::<#size>(&{
233 let mut __buf = Vec::new();
234 let __writer = &mut __buf;
241 for x in args.const_before.iter().rev() {
243 #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
248 for x in args.const_after.iter() {
251 #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
257 Err(e) => e.write_errors(),
261 fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
264 let mut code = body(&args);
269 let mut __owned_reader = mt_ser::flate2::read::ZlibDecoder::new(
270 mt_ser::WrapRead(__reader));
271 let __reader = &mut __owned_reader;
278 if let Some(size) = args.size {
280 #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
281 let mut __owned_reader = std::io::Read::take(
282 mt_ser::WrapRead(__reader), size as u64);
283 let __reader = &mut __owned_reader;
290 let impl_const = |value: TokStr| {
293 fn deserialize_same_type<T: MtDeserialize>(
295 reader: &mut impl std::io::Read
296 ) -> Result<T, mt_ser::DeserializeError> {
297 T::mt_deserialize::<mt_ser::DefCfg>(reader)
300 deserialize_same_type(&want, __reader)
305 Err(mt_ser::DeserializeError::InvalidConst(
306 Box::new(want), Box::new(got)
314 for want in args.const_before.iter().rev() {
315 let imp = impl_const(code);
324 for want in args.const_after.iter() {
325 let imp = impl_const(quote! { Ok(value) });
329 #code.and_then(|value| { #imp })
336 Err(e) => e.write_errors(),
340 fn serialize_fields(fields: &Fields) -> TokStr {
343 .map(|(ident, field)| {
344 serialize_args(MtArgs::from_field(field), |args| {
345 let def = parse_quote! { mt_ser::DefCfg };
346 let len = args.len.as_ref().unwrap_or(&def);
347 quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }
353 fn deserialize_fields(fields: &Fields) -> TokStr {
356 .map(|(ident, field)| {
357 let code = deserialize_args(MtArgs::from_field(field), |args| {
358 let def = parse_quote! { mt_ser::DefCfg };
359 let len = args.len.as_ref().unwrap_or(&def);
360 let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
364 mt_ser::OrDefault::or_default(#code)
378 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
379 let ident_fn = match input {
380 syn::Fields::Unnamed(_) => |f| {
382 mt_ser::paste::paste! { [<field_ #f>] }
385 _ => |f| quote! { #f },
388 let fields = get_fields(input, ident_fn);
389 let fields_comma: TokStr = fields
391 .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
393 let fields_struct = match input {
394 syn::Fields::Named(_) => quote! { { #fields_comma } },
395 syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
396 syn::Fields::Unit => TokStr::new(),
399 (fields, fields_struct)
402 fn get_repr(input: &syn::DeriveInput, args: &MtArgs) -> syn::Type {
403 if args.string_repr {
404 parse_quote! { &str }
409 .find(|a| a.path.is_ident("repr"))
410 .expect("missing repr")
412 .expect("invalid repr")
416 fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Variant, &syn::Expr)) {
417 let mut discr = parse_quote! { 0 };
419 for v in e.variants.iter() {
420 discr = if args.string_repr {
421 let lit = v.ident.to_string().to_case(Case::Snake);
422 parse_quote! { #lit }
424 v.discriminant.clone().map(|x| x.1).unwrap_or(discr)
429 discr = parse_quote! { 1 + #discr };
433 #[proc_macro_derive(MtSerialize, attributes(mt))]
434 pub fn derive_serialize(input: TokenStream) -> TokenStream {
435 let input = parse_macro_input!(input as syn::DeriveInput);
436 let typename = &input.ident;
438 let code = serialize_args(MtArgs::from_derive_input(&input), |args| {
440 syn::Data::Enum(e) => {
441 let repr = get_repr(&input, args);
442 let mut variants = TokStr::new();
444 iter_variants(e, args, |v, discr| {
445 let (fields, fields_struct) = get_fields_struct(&v.fields);
447 serialize_args(MtArgs::from_variant(v), |_| serialize_fields(&fields));
448 let ident = &v.ident;
450 variants.extend(quote! {
451 #typename::#ident #fields_struct => {
452 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
464 syn::Data::Struct(s) => {
465 serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f }))
468 panic!("only enum and struct supported");
474 #[automatically_derived]
475 impl mt_ser::MtSerialize for #typename {
476 fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
485 #[proc_macro_derive(MtDeserialize, attributes(mt))]
486 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
487 let input = parse_macro_input!(input as syn::DeriveInput);
488 let typename = &input.ident;
490 let code = deserialize_args(MtArgs::from_derive_input(&input), |args| {
492 syn::Data::Enum(e) => {
493 let repr = get_repr(&input, args);
495 let mut consts = TokStr::new();
496 let mut arms = TokStr::new();
498 iter_variants(e, args, |v, discr| {
499 let ident = &v.ident;
500 let (fields, fields_struct) = get_fields_struct(&v.fields);
501 let code = deserialize_args(MtArgs::from_variant(v), |_| {
502 let fields_code = deserialize_fields(&fields);
506 Ok(Self::#ident #fields_struct)
510 consts.extend(quote! {
511 const #ident: #repr = #discr;
519 let type_str = typename.to_string();
520 let discr_match = if args.string_repr {
522 let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
523 match __discr.as_str()
527 let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
537 _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
541 syn::Data::Struct(s) => {
542 let (fields, fields_struct) = get_fields_struct(&s.fields);
543 let code = deserialize_fields(&fields);
547 Ok(Self #fields_struct)
551 panic!("only enum and struct supported");
557 #[automatically_derived]
558 impl mt_ser::MtDeserialize for #typename {
559 #[allow(non_upper_case_globals)]
560 fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {