]> git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
Add multiplier and maps
[mt_ser.git] / derive / src / lib.rs
1 use convert_case::{Case, Casing};
2 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
3 use proc_macro::TokenStream;
4 use proc_macro2::TokenStream as TokStr;
5 use quote::{quote, ToTokens};
6 use syn::{parse_macro_input, parse_quote};
7
8 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
9 #[darling(rename_all = "snake_case")]
10 enum To {
11     Clt,
12     Srv,
13 }
14
15 #[derive(Debug, FromMeta)]
16 struct MacroArgs {
17     to: To,
18     repr: Option<syn::Type>,
19     tag: Option<String>,
20     content: Option<String>,
21     #[darling(default)]
22     custom: bool,
23     #[darling(default)]
24     enumset: bool,
25 }
26
27 fn wrap_attr(attr: &mut syn::Attribute) {
28     let path = attr.path.clone();
29     let tokens = attr.tokens.clone();
30
31     match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
32         Some("mt") => {
33             *attr = parse_quote! {
34                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
35             };
36         }
37         Some("serde") => {
38             *attr = parse_quote! {
39                 #[cfg_attr(feature = "serde", #path #tokens)]
40             };
41         }
42         _ => {}
43     }
44 }
45
46 #[proc_macro_attribute]
47 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
48     let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
49     let mut input = parse_macro_input!(item as syn::Item);
50
51     let args = match MacroArgs::from_list(&attr_args) {
52         Ok(v) => v,
53         Err(e) => {
54             return TokenStream::from(e.write_errors());
55         }
56     };
57
58     let (serializer, deserializer) = match args.to {
59         To::Clt => ("server", "client"),
60         To::Srv => ("client", "server"),
61     };
62
63     let mut out = quote! {
64         #[derive(Debug)]
65         #[cfg_attr(feature = "random", derive(GenerateRandom))]
66         #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
67     };
68
69     macro_rules! iter {
70         ($t:expr, $f:expr) => {
71             $t.iter_mut().for_each($f)
72         };
73     }
74
75     match &mut input {
76         syn::Item::Enum(e) => {
77             iter!(e.attrs, wrap_attr);
78             iter!(e.variants, |v| {
79                 iter!(v.attrs, wrap_attr);
80                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
81             });
82
83             if args.enumset {
84                 out.extend(quote! {
85                     #[derive(EnumSetType)]
86                     #[enumset(serialize_as_map)]
87                 });
88
89                 if let Some(repr) = args.repr {
90                     let repr_str = repr.to_token_stream().to_string();
91
92                     out.extend(quote! {
93                         #[enumset(repr = #repr_str)]
94                     });
95                 } else if !args.custom {
96                     panic!("missing repr for enum");
97                 }
98             } else {
99                 let has_payload = e
100                     .variants
101                     .iter()
102                     .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
103                     .is_some();
104
105                 if has_payload {
106                     let tag = args.tag.expect("missing tag for enum with payload");
107
108                     out.extend(quote! {
109                         #[cfg_attr(feature = "serde", serde(tag = #tag))]
110                     });
111
112                     if let Some(content) = args.content {
113                         out.extend(quote! {
114                             #[cfg_attr(feature = "serde", serde(content = #content))]
115                         });
116                     }
117                 } else {
118                     out.extend(quote! {
119                         #[derive(Copy, Eq)]
120                     });
121                 }
122
123                 out.extend(quote! {
124                     #[derive(Clone, PartialEq)]
125                 });
126
127                 if !args.custom {
128                     out.extend(quote! {
129                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
130                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
131                     });
132                 }
133
134                 if let Some(repr) = args.repr {
135                     if repr == parse_quote! { str } {
136                         out.extend(quote! {
137                                                         #[cfg_attr(any(feature = "client", feature = "server"), mt(string_repr))]
138                                                 });
139                     } else {
140                         out.extend(quote! {
141                             #[repr(#repr)]
142                         });
143                     }
144                 } else if !args.custom {
145                     panic!("missing repr for enum");
146                 }
147             }
148
149             out.extend(quote! {
150                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
151             });
152         }
153         syn::Item::Struct(s) => {
154             iter!(s.attrs, wrap_attr);
155             iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
156
157             out.extend(quote! {
158                 #[derive(Clone, PartialEq)]
159             });
160
161             if !args.custom {
162                 out.extend(quote! {
163                     #[cfg_attr(feature = #serializer, derive(MtSerialize))]
164                     #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
165                 });
166             }
167         }
168         _ => panic!("only enum and struct supported"),
169     }
170
171     out.extend(input.to_token_stream());
172     out.into()
173 }
174
175 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
176 #[darling(attributes(mt))]
177 #[darling(default)]
178 struct MtArgs {
179     #[darling(multiple)]
180     const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
181     #[darling(multiple)]
182     const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
183     size: Option<syn::Type>, // must implement MtCfg
184     len: Option<syn::Type>,  // must implement MtCfg
185     default: bool,           // type must implement Default
186     string_repr: bool,       // for enums
187     zlib: bool,
188     zstd: bool,
189     map_ser: Option<syn::Expr>,
190     map_des: Option<syn::Expr>,
191     multiplier: Option<syn::Expr>,
192     typename: Option<syn::Ident>, // remote derive
193     bounds: Option<syn::WhereClause>,
194 }
195
196 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
197
198 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
199     match fields {
200         syn::Fields::Named(fs) => fs
201             .named
202             .iter()
203             .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
204             .collect(),
205         syn::Fields::Unnamed(fs) => fs
206             .unnamed
207             .iter()
208             .enumerate()
209             .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
210             .collect(),
211         syn::Fields::Unit => Vec::new(),
212     }
213 }
214
215 fn serialize_args(args: &MtArgs, code: &mut TokStr) {
216     macro_rules! impl_compress {
217         ($create:expr) => {
218             *code = quote! {
219                 let mut __writer = {
220                     let mut __stream = $create;
221                     let __writer = &mut __stream;
222                     #code
223                     __stream.finish()?
224                 };
225             };
226         };
227     }
228
229     if args.zlib {
230         impl_compress!(mt_ser::flate2::write::ZlibEncoder::new(
231             __writer,
232             mt_ser::flate2::Compression::default()
233         ));
234     }
235
236     if args.zstd {
237         impl_compress!(mt_ser::zstd::stream::write::Encoder::new(__writer, 0)?);
238     }
239
240     if let Some(size) = &args.size {
241         *code = quote! {
242             mt_ser::MtSerialize::mt_serialize::<#size>(&{
243                 let mut __buf = Vec::new();
244                 let __writer = &mut __buf;
245                 #code
246                 __buf
247             }, __writer)?;
248         };
249     }
250
251     for x in args.const_before.iter().rev() {
252         *code = quote! {
253             #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
254             #code
255         }
256     }
257
258     for x in args.const_after.iter() {
259         *code = quote! {
260             #code
261             #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
262         }
263     }
264 }
265
266 fn deserialize_args(args: &MtArgs, code: &mut TokStr) {
267     macro_rules! impl_compress {
268         ($create:expr) => {
269             *code = quote! {
270                 {
271                     let mut __owned_reader = $create;
272                     let __reader = &mut __owned_reader;
273
274                     #code
275                 }
276             }
277         };
278     }
279
280     if args.zlib {
281         impl_compress!(mt_ser::flate2::read::ZlibDecoder::new(mt_ser::WrapRead(
282             __reader
283         )));
284     }
285
286     if args.zstd {
287         impl_compress!(mt_ser::zstd::stream::read::Decoder::new(mt_ser::WrapRead(
288             __reader
289         ))?);
290     }
291
292     if let Some(size) = &args.size {
293         *code = quote! {
294             {
295                 let __size = #size::mt_deserialize::<DefCfg>(__reader)? as u64;
296                 let mut __owned_reader = std::io::Read::take(
297                     mt_ser::WrapRead(__reader),
298                     __size,
299                 );
300                 let __reader = &mut __owned_reader;
301
302                 #code
303             }
304         };
305     }
306
307     let impl_const = |want: &syn::Expr| {
308         quote! {
309             {
310                 fn eq_same_type<T: PartialEq<T>>(a: &T, b: &T) -> bool {
311                     a == b
312                 }
313
314                 let want = #want;
315                 let got = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
316
317                 if !eq_same_type(&want, &got) {
318                     return Err(mt_ser::DeserializeError::InvalidConst(
319                         Box::new(want), Box::new(got)
320                     ));
321                 }
322             }
323         }
324     };
325
326     for want in args.const_before.iter().rev() {
327         let imp = impl_const(want);
328         *code = quote! {
329             {
330                 #imp
331                 #code
332             }
333         };
334     }
335
336     for want in args.const_after.iter() {
337         let imp = impl_const(want);
338         *code = quote! {
339             {
340                 let __result = #code;
341                 #imp
342                 __result
343             }
344         };
345     }
346 }
347
348 fn serialize_fields(fields: &Fields) -> TokStr {
349     fields
350         .iter()
351         .map(|(ident, field)| {
352             let args = MtArgs::from_field(field).unwrap();
353             let def = parse_quote! { mt_ser::DefCfg };
354             let len = args.len.as_ref().unwrap_or(&def);
355
356             let mut code = quote! { #ident };
357
358             if let Some(multiplier) = &args.multiplier {
359                 code = quote! {
360                     &((#code) * (#multiplier))
361                 };
362             }
363
364             if let Some(map) = &args.map_ser {
365                 code = quote! {
366                     {
367                         fn call_ser_result<I, O>(
368                             f: impl FnOnce(I) -> Result<O, mt_ser::SerializeError>,
369                             i: I
370                         ) -> Result<O, mt_ser::SerializeError> {
371                             f(i)
372                         }
373
374                         &call_ser_result(#map, #code)?
375                     }
376                 };
377             }
378
379             code = quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#code, __writer)?; };
380
381             serialize_args(&args, &mut code);
382
383             code
384         })
385         .collect()
386 }
387
388 fn deserialize_fields(fields: &Fields) -> TokStr {
389     fields
390         .iter()
391         .map(|(ident, field)| {
392             let args = MtArgs::from_field(field).unwrap();
393
394             let def = parse_quote! { mt_ser::DefCfg };
395             let len = args.len.as_ref().unwrap_or(&def);
396             let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
397
398             if args.default {
399                 code = quote! {
400                     mt_ser::OrDefault::or_default(#code)
401                 };
402             }
403
404             code = quote! {
405                 (#code)?
406             };
407
408             deserialize_args(&args, &mut code);
409
410             if let Some(map) = &args.map_des {
411                 code = quote! {
412                     {
413                         fn call_des_result<I, O>(
414                             f: impl FnOnce(I) -> Result<O, mt_ser::DeserializeError>,
415                             i: I
416                         ) -> Result<O, mt_ser::DeserializeError> {
417                             f(i)
418                         }
419
420                         call_des_result(#map, #code)?
421                     }
422                 };
423             }
424
425             if let Some(multiplier) = &args.multiplier {
426                 code = quote! {
427                     {
428                         fn div_same_type<D, T: std::ops::Div<D, Output = T>>(a: T, b: D) -> T {
429                             a / b
430                         }
431
432                         div_same_type(#code, #multiplier)
433                     }
434                 }
435             }
436
437             quote! {
438                 let #ident = #code;
439             }
440         })
441         .collect()
442 }
443
444 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
445     let ident_fn = match input {
446         syn::Fields::Unnamed(_) => |f| {
447             quote! {
448                 mt_ser::paste::paste! { [<field_ #f>] }
449             }
450         },
451         _ => |f| quote! { #f },
452     };
453
454     let fields = get_fields(input, ident_fn);
455     let fields_comma: TokStr = fields
456         .iter()
457         .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
458
459     let fields_struct = match input {
460         syn::Fields::Named(_) => quote! { { #fields_comma } },
461         syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
462         syn::Fields::Unit => TokStr::new(),
463     };
464
465     (fields, fields_struct)
466 }
467
468 fn get_repr(input: &syn::DeriveInput, args: &MtArgs) -> syn::Type {
469     if args.string_repr {
470         parse_quote! { &str }
471     } else {
472         input
473             .attrs
474             .iter()
475             .find(|a| a.path.is_ident("repr"))
476             .expect("missing repr")
477             .parse_args()
478             .expect("invalid repr")
479     }
480 }
481
482 fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Variant, &syn::Expr)) {
483     let mut discr = parse_quote! { 0 };
484
485     for v in e.variants.iter() {
486         discr = if args.string_repr {
487             let lit = v.ident.to_string().to_case(Case::Snake);
488             parse_quote! { #lit }
489         } else {
490             v.discriminant.clone().map(|x| x.1).unwrap_or(discr)
491         };
492
493         f(v, &discr);
494
495         discr = parse_quote! { 1 + #discr };
496     }
497 }
498
499 fn make_impl(
500     traitname: TokStr,
501     input: &syn::DeriveInput,
502     typename: &syn::Ident,
503     args: &MtArgs,
504     code: TokStr,
505 ) -> TokenStream {
506     let generics = &input.generics;
507     let bounds = args.bounds.clone().or_else(|| {
508         if generics.params.is_empty() {
509             None
510         } else {
511             Some(
512                 syn::parse(
513                     generics
514                         .params
515                         .iter()
516                         .rfold(quote! { where }, |before, t| match t {
517                             syn::GenericParam::Type(x) => quote! { #before #x: #traitname, },
518                             _ => before,
519                         })
520                         .into(),
521                 )
522                 .expect("invalid where clause"),
523             )
524         }
525     });
526
527     quote! {
528         #[automatically_derived]
529         impl #generics #traitname for #typename #generics #bounds { #code }
530     }
531     .into()
532 }
533
534 #[proc_macro_derive(MtSerialize, attributes(mt))]
535 pub fn derive_serialize(input: TokenStream) -> TokenStream {
536     let input = parse_macro_input!(input as syn::DeriveInput);
537     let args = MtArgs::from_derive_input(&input).unwrap();
538     let typename = args.typename.as_ref().unwrap_or(&input.ident);
539
540     let mut code = match &input.data {
541         syn::Data::Enum(e) => {
542             let repr = get_repr(&input, &args);
543             let mut variants = TokStr::new();
544
545             iter_variants(e, &args, |v, discr| {
546                 let args = MtArgs::from_variant(v).unwrap();
547
548                 let (fields, fields_struct) = get_fields_struct(&v.fields);
549
550                 let mut code = serialize_fields(&fields);
551                 serialize_args(&args, &mut code);
552
553                 let ident = &v.ident;
554
555                 variants.extend(quote! {
556                                         #typename::#ident #fields_struct => {
557                                                 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
558                                                 #code
559                                         }
560                                 });
561             });
562
563             quote! {
564                 match self {
565                     #variants
566                 }
567             }
568         }
569         syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
570         _ => {
571             panic!("only enum and struct supported");
572         }
573     };
574
575     serialize_args(&args, &mut code);
576
577     make_impl(
578         quote! { mt_ser::MtSerialize },
579         &input,
580         typename,
581         &args,
582         quote! {
583             fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
584                 #code
585
586                 Ok(())
587             }
588         },
589     )
590 }
591
592 #[proc_macro_derive(MtDeserialize, attributes(mt))]
593 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
594     let input = parse_macro_input!(input as syn::DeriveInput);
595     let args = MtArgs::from_derive_input(&input).unwrap();
596     let typename = args.typename.as_ref().unwrap_or(&input.ident);
597
598     let mut code = match &input.data {
599         syn::Data::Enum(e) => {
600             let repr = get_repr(&input, &args);
601
602             let mut consts = TokStr::new();
603             let mut arms = TokStr::new();
604
605             iter_variants(e, &args, |v, discr| {
606                 let args = MtArgs::from_variant(v).unwrap();
607
608                 let ident = &v.ident;
609                 let (fields, fields_struct) = get_fields_struct(&v.fields);
610                 let mut code = deserialize_fields(&fields);
611                 code = quote! {
612                     #code
613                     Ok(Self::#ident #fields_struct)
614                 };
615
616                 deserialize_args(&args, &mut code);
617
618                 consts.extend(quote! {
619                     const #ident: #repr = #discr;
620                 });
621
622                 arms.extend(quote! {
623                     #ident => { #code }
624                 });
625             });
626
627             let type_str = typename.to_string();
628             let discr_match = if args.string_repr {
629                 quote! {
630                     let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
631                     match __discr.as_str()
632                 }
633             } else {
634                 quote! {
635                     let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
636                     match __discr
637                 }
638             };
639
640             quote! {
641                 #consts
642
643                 #discr_match {
644                     #arms
645                     _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
646                 }
647             }
648         }
649         syn::Data::Struct(s) => {
650             let (fields, fields_struct) = get_fields_struct(&s.fields);
651             let code = deserialize_fields(&fields);
652
653             quote! {
654                 #code
655                 Ok(Self #fields_struct)
656             }
657         }
658         _ => {
659             panic!("only enum and struct supported");
660         }
661     };
662
663     deserialize_args(&args, &mut code);
664
665     make_impl(
666         quote! { mt_ser::MtDeserialize },
667         &input,
668         typename,
669         &args,
670         quote! {
671             #[allow(non_upper_case_globals)]
672             fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
673                 #code
674             }
675         },
676     )
677 }