]> git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
derive: Only emit mt attribute if client or server feature enabled
[mt_ser.git] / derive / src / lib.rs
1 use convert_case::{Case, Casing};
2 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
3 use proc_macro::TokenStream;
4 use proc_macro2::TokenStream as TokStr;
5 use quote::{quote, ToTokens};
6 use syn::{parse_macro_input, parse_quote};
7
8 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
9 #[darling(rename_all = "snake_case")]
10 enum To {
11     Clt,
12     Srv,
13 }
14
15 #[derive(Debug, FromMeta)]
16 struct MacroArgs {
17     to: To,
18     repr: Option<syn::Type>,
19     tag: Option<String>,
20     content: Option<String>,
21     #[darling(default)]
22     custom: bool,
23     #[darling(default)]
24     enumset: bool,
25 }
26
27 fn wrap_attr(attr: &mut syn::Attribute) {
28     let path = attr.path.clone();
29     let tokens = attr.tokens.clone();
30
31     match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
32         Some("mt") => {
33             *attr = parse_quote! {
34                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
35             };
36         }
37         Some("serde") => {
38             *attr = parse_quote! {
39                 #[cfg_attr(feature = "serde", #path #tokens)]
40             };
41         }
42         _ => {}
43     }
44 }
45
46 #[proc_macro_attribute]
47 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
48     let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
49     let mut input = parse_macro_input!(item as syn::Item);
50
51     let args = match MacroArgs::from_list(&attr_args) {
52         Ok(v) => v,
53         Err(e) => {
54             return TokenStream::from(e.write_errors());
55         }
56     };
57
58     let (serializer, deserializer) = match args.to {
59         To::Clt => ("server", "client"),
60         To::Srv => ("client", "server"),
61     };
62
63     let mut out = quote! {
64         #[derive(Debug)]
65         #[cfg_attr(feature = "random", derive(GenerateRandom))]
66         #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
67     };
68
69     macro_rules! iter {
70         ($t:expr, $f:expr) => {
71             $t.iter_mut().for_each($f)
72         };
73     }
74
75     match &mut input {
76         syn::Item::Enum(e) => {
77             iter!(e.attrs, wrap_attr);
78             iter!(e.variants, |v| {
79                 iter!(v.attrs, wrap_attr);
80                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
81             });
82
83             if args.enumset {
84                 let repr_str = args
85                     .repr
86                     .expect("missing repr for enum")
87                     .to_token_stream()
88                     .to_string();
89
90                 out.extend(quote! {
91                     #[derive(EnumSetType)]
92                     #[enumset(repr = #repr_str, serialize_as_map)]
93                 })
94             } else {
95                 let has_payload = e
96                     .variants
97                     .iter()
98                     .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
99                     .is_some();
100
101                 if has_payload {
102                     let tag = args.tag.expect("missing tag for enum with payload");
103
104                     out.extend(quote! {
105                         #[cfg_attr(feature = "serde", serde(tag = #tag))]
106                     });
107
108                     if let Some(content) = args.content {
109                         out.extend(quote! {
110                             #[cfg_attr(feature = "serde", serde(content = #content))]
111                         });
112                     }
113                 } else {
114                     out.extend(quote! {
115                         #[derive(Copy, Eq)]
116                     });
117                 }
118
119                 out.extend(quote! {
120                     #[derive(Clone, PartialEq)]
121                 });
122
123                 if !args.custom {
124                     out.extend(quote! {
125                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
126                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
127                     });
128                 }
129
130                 if let Some(repr) = args.repr {
131                     if repr == parse_quote! { str } {
132                         out.extend(quote! {
133                                                         #[cfg_attr(any(feature = "client", feature = "server"), mt(string_repr))]
134                                                 });
135                     } else {
136                         out.extend(quote! {
137                             #[repr(#repr)]
138                         });
139                     }
140                 } else if !args.custom {
141                     panic!("missing repr for enum");
142                 }
143             }
144
145             out.extend(quote! {
146                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
147             });
148         }
149         syn::Item::Struct(s) => {
150             iter!(s.attrs, wrap_attr);
151             iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
152
153             out.extend(quote! {
154                 #[derive(Clone, PartialEq)]
155             });
156
157             if !args.custom {
158                 out.extend(quote! {
159                     #[cfg_attr(feature = #serializer, derive(MtSerialize))]
160                     #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
161                 });
162             }
163         }
164         _ => panic!("only enum and struct supported"),
165     }
166
167     out.extend(input.to_token_stream());
168     out.into()
169 }
170
171 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
172 #[darling(attributes(mt))]
173 #[darling(default)]
174 struct MtArgs {
175     #[darling(multiple)]
176     const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
177
178     #[darling(multiple)]
179     const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
180
181     size: Option<syn::Type>, // must implement MtCfg
182
183     len: Option<syn::Type>, // must implement MtCfg
184
185     zlib: bool,
186     zstd: bool,    // TODO
187     default: bool, // type must implement Default
188
189     string_repr: bool, // for enums
190 }
191
192 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
193
194 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
195     match fields {
196         syn::Fields::Named(fs) => fs
197             .named
198             .iter()
199             .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
200             .collect(),
201         syn::Fields::Unnamed(fs) => fs
202             .unnamed
203             .iter()
204             .enumerate()
205             .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
206             .collect(),
207         syn::Fields::Unit => Vec::new(),
208     }
209 }
210
211 fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
212     match res {
213         Ok(args) => {
214             let mut code = body(&args);
215
216             if args.zlib {
217                 code = quote! {
218                     let mut __writer = {
219                         let mut __stream = mt_ser::flate2::write::ZlibEncoder::new(
220                             __writer,
221                             mt_ser::flate2::Compression::default(),
222                         );
223                         let __writer = &mut __stream;
224                         #code
225                         __stream.finish()?
226                     };
227                 };
228             }
229
230             if let Some(size) = args.size {
231                 code = quote! {
232                     mt_ser::MtSerialize::mt_serialize::<#size>(&{
233                         let mut __buf = Vec::new();
234                         let __writer = &mut __buf;
235                         #code
236                         __buf
237                     }, __writer)?;
238                 };
239             }
240
241             for x in args.const_before.iter().rev() {
242                 code = quote! {
243                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
244                     #code
245                 }
246             }
247
248             for x in args.const_after.iter() {
249                 code = quote! {
250                     #code
251                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
252                 }
253             }
254
255             code
256         }
257         Err(e) => e.write_errors(),
258     }
259 }
260
261 fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
262     match res {
263         Ok(args) => {
264             let mut code = body(&args);
265
266             if args.zlib {
267                 code = quote! {
268                     {
269                         let mut __owned_reader = mt_ser::flate2::read::ZlibDecoder::new(
270                             mt_ser::WrapRead(__reader));
271                         let __reader = &mut __owned_reader;
272
273                         #code
274                     }
275                 }
276             }
277
278             if let Some(size) = args.size {
279                 code = quote! {
280                     #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
281                         let mut __owned_reader = std::io::Read::take(
282                             mt_ser::WrapRead(__reader), size as u64);
283                         let __reader = &mut __owned_reader;
284
285                         #code
286                     })
287                 };
288             }
289
290             let impl_const = |value: TokStr| {
291                 quote! {
292                     {
293                         fn deserialize_same_type<T: MtDeserialize>(
294                             _: &T,
295                             reader: &mut impl std::io::Read
296                         ) -> Result<T, mt_ser::DeserializeError> {
297                             T::mt_deserialize::<mt_ser::DefCfg>(reader)
298                         }
299
300                         deserialize_same_type(&want, __reader)
301                             .and_then(|got| {
302                                 if want == got {
303                                     #value
304                                 } else {
305                                     Err(mt_ser::DeserializeError::InvalidConst(
306                                         Box::new(want), Box::new(got)
307                                     ))
308                                 }
309                             })
310                     }
311                 }
312             };
313
314             for want in args.const_before.iter().rev() {
315                 let imp = impl_const(code);
316                 code = quote! {
317                     {
318                         let want = #want;
319                         #imp
320                     }
321                 };
322             }
323
324             for want in args.const_after.iter() {
325                 let imp = impl_const(quote! { Ok(value) });
326                 code = quote! {
327                     {
328                         let want = #want;
329                         #code.and_then(|value| { #imp })
330                     }
331                 };
332             }
333
334             code
335         }
336         Err(e) => e.write_errors(),
337     }
338 }
339
340 fn serialize_fields(fields: &Fields) -> TokStr {
341     fields
342         .iter()
343         .map(|(ident, field)| {
344             serialize_args(MtArgs::from_field(field), |args| {
345                 let def = parse_quote! { mt_ser::DefCfg };
346                 let len = args.len.as_ref().unwrap_or(&def);
347                 quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }
348             })
349         })
350         .collect()
351 }
352
353 fn deserialize_fields(fields: &Fields) -> TokStr {
354     fields
355         .iter()
356         .map(|(ident, field)| {
357             let code = deserialize_args(MtArgs::from_field(field), |args| {
358                 let def = parse_quote! { mt_ser::DefCfg };
359                 let len = args.len.as_ref().unwrap_or(&def);
360                 let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
361
362                 if args.default {
363                     code = quote! {
364                         mt_ser::OrDefault::or_default(#code)
365                     };
366                 }
367
368                 code
369             });
370
371             quote! {
372                 let #ident = #code?;
373             }
374         })
375         .collect()
376 }
377
378 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
379     let ident_fn = match input {
380         syn::Fields::Unnamed(_) => |f| {
381             quote! {
382                 mt_ser::paste::paste! { [<field_ #f>] }
383             }
384         },
385         _ => |f| quote! { #f },
386     };
387
388     let fields = get_fields(input, ident_fn);
389     let fields_comma: TokStr = fields
390         .iter()
391         .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
392
393     let fields_struct = match input {
394         syn::Fields::Named(_) => quote! { { #fields_comma } },
395         syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
396         syn::Fields::Unit => TokStr::new(),
397     };
398
399     (fields, fields_struct)
400 }
401
402 fn get_repr(input: &syn::DeriveInput, args: &MtArgs) -> syn::Type {
403     if args.string_repr {
404         parse_quote! { &str }
405     } else {
406         input
407             .attrs
408             .iter()
409             .find(|a| a.path.is_ident("repr"))
410             .expect("missing repr")
411             .parse_args()
412             .expect("invalid repr")
413     }
414 }
415
416 fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Variant, &syn::Expr)) {
417     let mut discr = parse_quote! { 0 };
418
419     for v in e.variants.iter() {
420         discr = if args.string_repr {
421             let lit = v.ident.to_string().to_case(Case::Snake);
422             parse_quote! { #lit }
423         } else {
424             v.discriminant.clone().map(|x| x.1).unwrap_or(discr)
425         };
426
427         f(v, &discr);
428
429         discr = parse_quote! { 1 + #discr };
430     }
431 }
432
433 #[proc_macro_derive(MtSerialize, attributes(mt))]
434 pub fn derive_serialize(input: TokenStream) -> TokenStream {
435     let input = parse_macro_input!(input as syn::DeriveInput);
436     let typename = &input.ident;
437
438     let code = serialize_args(MtArgs::from_derive_input(&input), |args| {
439         match &input.data {
440             syn::Data::Enum(e) => {
441                 let repr = get_repr(&input, args);
442                 let mut variants = TokStr::new();
443
444                 iter_variants(e, args, |v, discr| {
445                     let (fields, fields_struct) = get_fields_struct(&v.fields);
446                     let code =
447                         serialize_args(MtArgs::from_variant(v), |_| serialize_fields(&fields));
448                     let ident = &v.ident;
449
450                     variants.extend(quote! {
451                                         #typename::#ident #fields_struct => {
452                                                 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
453                                                 #code
454                                         }
455                                 });
456                 });
457
458                 quote! {
459                     match self {
460                         #variants
461                     }
462                 }
463             }
464             syn::Data::Struct(s) => {
465                 serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f }))
466             }
467             _ => {
468                 panic!("only enum and struct supported");
469             }
470         }
471     });
472
473     quote! {
474                 #[automatically_derived]
475                 impl mt_ser::MtSerialize for #typename {
476                         fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
477                                 #code
478
479                                 Ok(())
480                         }
481                 }
482         }.into()
483 }
484
485 #[proc_macro_derive(MtDeserialize, attributes(mt))]
486 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
487     let input = parse_macro_input!(input as syn::DeriveInput);
488     let typename = &input.ident;
489
490     let code = deserialize_args(MtArgs::from_derive_input(&input), |args| {
491         match &input.data {
492             syn::Data::Enum(e) => {
493                 let repr = get_repr(&input, args);
494
495                 let mut consts = TokStr::new();
496                 let mut arms = TokStr::new();
497
498                 iter_variants(e, args, |v, discr| {
499                     let ident = &v.ident;
500                     let (fields, fields_struct) = get_fields_struct(&v.fields);
501                     let code = deserialize_args(MtArgs::from_variant(v), |_| {
502                         let fields_code = deserialize_fields(&fields);
503
504                         quote! {
505                             #fields_code
506                             Ok(Self::#ident #fields_struct)
507                         }
508                     });
509
510                     consts.extend(quote! {
511                         const #ident: #repr = #discr;
512                     });
513
514                     arms.extend(quote! {
515                         #ident => { #code }
516                     });
517                 });
518
519                 let type_str = typename.to_string();
520                 let discr_match = if args.string_repr {
521                     quote! {
522                         let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
523                         match __discr.as_str()
524                     }
525                 } else {
526                     quote! {
527                         let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
528                         match __discr
529                     }
530                 };
531
532                 quote! {
533                     #consts
534
535                     #discr_match {
536                         #arms
537                         _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
538                     }
539                 }
540             }
541             syn::Data::Struct(s) => {
542                 let (fields, fields_struct) = get_fields_struct(&s.fields);
543                 let code = deserialize_fields(&fields);
544
545                 quote! {
546                     #code
547                     Ok(Self #fields_struct)
548                 }
549             }
550             _ => {
551                 panic!("only enum and struct supported");
552             }
553         }
554     });
555
556     quote! {
557                 #[automatically_derived]
558                 impl mt_ser::MtDeserialize for #typename {
559                         #[allow(non_upper_case_globals)]
560                         fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
561                                 #code
562                         }
563                 }
564         }.into()
565 }