]> git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
Implement zstd compression
[mt_ser.git] / derive / src / lib.rs
1 use convert_case::{Case, Casing};
2 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
3 use proc_macro::TokenStream;
4 use proc_macro2::TokenStream as TokStr;
5 use quote::{quote, ToTokens};
6 use syn::{parse_macro_input, parse_quote};
7
8 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
9 #[darling(rename_all = "snake_case")]
10 enum To {
11     Clt,
12     Srv,
13 }
14
15 #[derive(Debug, FromMeta)]
16 struct MacroArgs {
17     to: To,
18     repr: Option<syn::Type>,
19     tag: Option<String>,
20     content: Option<String>,
21     #[darling(default)]
22     custom: bool,
23     #[darling(default)]
24     enumset: bool,
25 }
26
27 fn wrap_attr(attr: &mut syn::Attribute) {
28     let path = attr.path.clone();
29     let tokens = attr.tokens.clone();
30
31     match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
32         Some("mt") => {
33             *attr = parse_quote! {
34                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
35             };
36         }
37         Some("serde") => {
38             *attr = parse_quote! {
39                 #[cfg_attr(feature = "serde", #path #tokens)]
40             };
41         }
42         _ => {}
43     }
44 }
45
46 #[proc_macro_attribute]
47 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
48     let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
49     let mut input = parse_macro_input!(item as syn::Item);
50
51     let args = match MacroArgs::from_list(&attr_args) {
52         Ok(v) => v,
53         Err(e) => {
54             return TokenStream::from(e.write_errors());
55         }
56     };
57
58     let (serializer, deserializer) = match args.to {
59         To::Clt => ("server", "client"),
60         To::Srv => ("client", "server"),
61     };
62
63     let mut out = quote! {
64         #[derive(Debug)]
65         #[cfg_attr(feature = "random", derive(GenerateRandom))]
66         #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
67     };
68
69     macro_rules! iter {
70         ($t:expr, $f:expr) => {
71             $t.iter_mut().for_each($f)
72         };
73     }
74
75     match &mut input {
76         syn::Item::Enum(e) => {
77             iter!(e.attrs, wrap_attr);
78             iter!(e.variants, |v| {
79                 iter!(v.attrs, wrap_attr);
80                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
81             });
82
83             if args.enumset {
84                 let repr_str = args
85                     .repr
86                     .expect("missing repr for enum")
87                     .to_token_stream()
88                     .to_string();
89
90                 out.extend(quote! {
91                     #[derive(EnumSetType)]
92                     #[enumset(repr = #repr_str, serialize_as_map)]
93                 })
94             } else {
95                 let has_payload = e
96                     .variants
97                     .iter()
98                     .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
99                     .is_some();
100
101                 if has_payload {
102                     let tag = args.tag.expect("missing tag for enum with payload");
103
104                     out.extend(quote! {
105                         #[cfg_attr(feature = "serde", serde(tag = #tag))]
106                     });
107
108                     if let Some(content) = args.content {
109                         out.extend(quote! {
110                             #[cfg_attr(feature = "serde", serde(content = #content))]
111                         });
112                     }
113                 } else {
114                     out.extend(quote! {
115                         #[derive(Copy, Eq)]
116                     });
117                 }
118
119                 out.extend(quote! {
120                     #[derive(Clone, PartialEq)]
121                 });
122
123                 if !args.custom {
124                     out.extend(quote! {
125                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
126                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
127                     });
128                 }
129
130                 if let Some(repr) = args.repr {
131                     if repr == parse_quote! { str } {
132                         out.extend(quote! {
133                                                         #[cfg_attr(any(feature = "client", feature = "server"), mt(string_repr))]
134                                                 });
135                     } else {
136                         out.extend(quote! {
137                             #[repr(#repr)]
138                         });
139                     }
140                 } else if !args.custom {
141                     panic!("missing repr for enum");
142                 }
143             }
144
145             out.extend(quote! {
146                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
147             });
148         }
149         syn::Item::Struct(s) => {
150             iter!(s.attrs, wrap_attr);
151             iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
152
153             out.extend(quote! {
154                 #[derive(Clone, PartialEq)]
155             });
156
157             if !args.custom {
158                 out.extend(quote! {
159                     #[cfg_attr(feature = #serializer, derive(MtSerialize))]
160                     #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
161                 });
162             }
163         }
164         _ => panic!("only enum and struct supported"),
165     }
166
167     out.extend(input.to_token_stream());
168     out.into()
169 }
170
171 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
172 #[darling(attributes(mt))]
173 #[darling(default)]
174 struct MtArgs {
175     #[darling(multiple)]
176     const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
177     #[darling(multiple)]
178     const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
179     size: Option<syn::Type>, // must implement MtCfg
180     len: Option<syn::Type>,  // must implement MtCfg
181     default: bool,           // type must implement Default
182     string_repr: bool,       // for enums
183     zlib: bool,
184     zstd: bool,
185 }
186
187 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
188
189 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
190     match fields {
191         syn::Fields::Named(fs) => fs
192             .named
193             .iter()
194             .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
195             .collect(),
196         syn::Fields::Unnamed(fs) => fs
197             .unnamed
198             .iter()
199             .enumerate()
200             .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
201             .collect(),
202         syn::Fields::Unit => Vec::new(),
203     }
204 }
205
206 fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
207     match res {
208         Ok(args) => {
209             let mut code = body(&args);
210
211             macro_rules! impl_compress {
212                 ($create:expr) => {
213                     code = quote! {
214                         let mut __writer = {
215                             let mut __stream = $create;
216                             let __writer = &mut __stream;
217                             #code
218                             __stream.finish()?
219                         };
220                     };
221                 };
222             }
223
224             if args.zlib {
225                 impl_compress!(mt_ser::flate2::write::ZlibEncoder::new(
226                     __writer,
227                     mt_ser::flate2::Compression::default()
228                 ));
229             }
230
231             if args.zstd {
232                 impl_compress!(mt_ser::zstd::stream::write::Encoder::new(__writer, 0)?);
233             }
234
235             if let Some(size) = args.size {
236                 code = quote! {
237                     mt_ser::MtSerialize::mt_serialize::<#size>(&{
238                         let mut __buf = Vec::new();
239                         let __writer = &mut __buf;
240                         #code
241                         __buf
242                     }, __writer)?;
243                 };
244             }
245
246             for x in args.const_before.iter().rev() {
247                 code = quote! {
248                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
249                     #code
250                 }
251             }
252
253             for x in args.const_after.iter() {
254                 code = quote! {
255                     #code
256                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
257                 }
258             }
259
260             code
261         }
262         Err(e) => e.write_errors(),
263     }
264 }
265
266 fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
267     match res {
268         Ok(args) => {
269             let mut code = body(&args);
270
271             macro_rules! impl_compress {
272                 ($create:expr) => {
273                     code = quote! {
274                         {
275                             let mut __owned_reader = $create;
276                             let __reader = &mut __owned_reader;
277
278                             #code
279                         }
280                     }
281                 };
282             }
283
284             if args.zlib {
285                 impl_compress!(mt_ser::flate2::read::ZlibDecoder::new(mt_ser::WrapRead(
286                     __reader
287                 )));
288             }
289
290             if args.zstd {
291                 impl_compress!(mt_ser::zstd::stream::read::Decoder::new(mt_ser::WrapRead(
292                     __reader
293                 ))?);
294             }
295
296             if let Some(size) = args.size {
297                 code = quote! {
298                     #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
299                         let mut __owned_reader = std::io::Read::take(
300                             mt_ser::WrapRead(__reader), size as u64);
301                         let __reader = &mut __owned_reader;
302
303                         #code
304                     })
305                 };
306             }
307
308             let impl_const = |value: TokStr| {
309                 quote! {
310                     {
311                         fn deserialize_same_type<T: MtDeserialize>(
312                             _: &T,
313                             reader: &mut impl std::io::Read
314                         ) -> Result<T, mt_ser::DeserializeError> {
315                             T::mt_deserialize::<mt_ser::DefCfg>(reader)
316                         }
317
318                         deserialize_same_type(&want, __reader)
319                             .and_then(|got| {
320                                 if want == got {
321                                     #value
322                                 } else {
323                                     Err(mt_ser::DeserializeError::InvalidConst(
324                                         Box::new(want), Box::new(got)
325                                     ))
326                                 }
327                             })
328                     }
329                 }
330             };
331
332             for want in args.const_before.iter().rev() {
333                 let imp = impl_const(code);
334                 code = quote! {
335                     {
336                         let want = #want;
337                         #imp
338                     }
339                 };
340             }
341
342             for want in args.const_after.iter() {
343                 let imp = impl_const(quote! { Ok(value) });
344                 code = quote! {
345                     {
346                         let want = #want;
347                         #code.and_then(|value| { #imp })
348                     }
349                 };
350             }
351
352             code
353         }
354         Err(e) => e.write_errors(),
355     }
356 }
357
358 fn serialize_fields(fields: &Fields) -> TokStr {
359     fields
360         .iter()
361         .map(|(ident, field)| {
362             serialize_args(MtArgs::from_field(field), |args| {
363                 let def = parse_quote! { mt_ser::DefCfg };
364                 let len = args.len.as_ref().unwrap_or(&def);
365                 quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }
366             })
367         })
368         .collect()
369 }
370
371 fn deserialize_fields(fields: &Fields) -> TokStr {
372     fields
373         .iter()
374         .map(|(ident, field)| {
375             let code = deserialize_args(MtArgs::from_field(field), |args| {
376                 let def = parse_quote! { mt_ser::DefCfg };
377                 let len = args.len.as_ref().unwrap_or(&def);
378                 let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
379
380                 if args.default {
381                     code = quote! {
382                         mt_ser::OrDefault::or_default(#code)
383                     };
384                 }
385
386                 code
387             });
388
389             quote! {
390                 let #ident = #code?;
391             }
392         })
393         .collect()
394 }
395
396 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
397     let ident_fn = match input {
398         syn::Fields::Unnamed(_) => |f| {
399             quote! {
400                 mt_ser::paste::paste! { [<field_ #f>] }
401             }
402         },
403         _ => |f| quote! { #f },
404     };
405
406     let fields = get_fields(input, ident_fn);
407     let fields_comma: TokStr = fields
408         .iter()
409         .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
410
411     let fields_struct = match input {
412         syn::Fields::Named(_) => quote! { { #fields_comma } },
413         syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
414         syn::Fields::Unit => TokStr::new(),
415     };
416
417     (fields, fields_struct)
418 }
419
420 fn get_repr(input: &syn::DeriveInput, args: &MtArgs) -> syn::Type {
421     if args.string_repr {
422         parse_quote! { &str }
423     } else {
424         input
425             .attrs
426             .iter()
427             .find(|a| a.path.is_ident("repr"))
428             .expect("missing repr")
429             .parse_args()
430             .expect("invalid repr")
431     }
432 }
433
434 fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Variant, &syn::Expr)) {
435     let mut discr = parse_quote! { 0 };
436
437     for v in e.variants.iter() {
438         discr = if args.string_repr {
439             let lit = v.ident.to_string().to_case(Case::Snake);
440             parse_quote! { #lit }
441         } else {
442             v.discriminant.clone().map(|x| x.1).unwrap_or(discr)
443         };
444
445         f(v, &discr);
446
447         discr = parse_quote! { 1 + #discr };
448     }
449 }
450
451 #[proc_macro_derive(MtSerialize, attributes(mt))]
452 pub fn derive_serialize(input: TokenStream) -> TokenStream {
453     let input = parse_macro_input!(input as syn::DeriveInput);
454     let typename = &input.ident;
455
456     let code = serialize_args(MtArgs::from_derive_input(&input), |args| {
457         match &input.data {
458             syn::Data::Enum(e) => {
459                 let repr = get_repr(&input, args);
460                 let mut variants = TokStr::new();
461
462                 iter_variants(e, args, |v, discr| {
463                     let (fields, fields_struct) = get_fields_struct(&v.fields);
464                     let code =
465                         serialize_args(MtArgs::from_variant(v), |_| serialize_fields(&fields));
466                     let ident = &v.ident;
467
468                     variants.extend(quote! {
469                                         #typename::#ident #fields_struct => {
470                                                 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
471                                                 #code
472                                         }
473                                 });
474                 });
475
476                 quote! {
477                     match self {
478                         #variants
479                     }
480                 }
481             }
482             syn::Data::Struct(s) => {
483                 serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f }))
484             }
485             _ => {
486                 panic!("only enum and struct supported");
487             }
488         }
489     });
490
491     quote! {
492                 #[automatically_derived]
493                 impl mt_ser::MtSerialize for #typename {
494                         fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
495                                 #code
496
497                                 Ok(())
498                         }
499                 }
500         }.into()
501 }
502
503 #[proc_macro_derive(MtDeserialize, attributes(mt))]
504 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
505     let input = parse_macro_input!(input as syn::DeriveInput);
506     let typename = &input.ident;
507
508     let code = deserialize_args(MtArgs::from_derive_input(&input), |args| {
509         match &input.data {
510             syn::Data::Enum(e) => {
511                 let repr = get_repr(&input, args);
512
513                 let mut consts = TokStr::new();
514                 let mut arms = TokStr::new();
515
516                 iter_variants(e, args, |v, discr| {
517                     let ident = &v.ident;
518                     let (fields, fields_struct) = get_fields_struct(&v.fields);
519                     let code = deserialize_args(MtArgs::from_variant(v), |_| {
520                         let fields_code = deserialize_fields(&fields);
521
522                         quote! {
523                             #fields_code
524                             Ok(Self::#ident #fields_struct)
525                         }
526                     });
527
528                     consts.extend(quote! {
529                         const #ident: #repr = #discr;
530                     });
531
532                     arms.extend(quote! {
533                         #ident => { #code }
534                     });
535                 });
536
537                 let type_str = typename.to_string();
538                 let discr_match = if args.string_repr {
539                     quote! {
540                         let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
541                         match __discr.as_str()
542                     }
543                 } else {
544                     quote! {
545                         let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
546                         match __discr
547                     }
548                 };
549
550                 quote! {
551                     #consts
552
553                     #discr_match {
554                         #arms
555                         _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
556                     }
557                 }
558             }
559             syn::Data::Struct(s) => {
560                 let (fields, fields_struct) = get_fields_struct(&s.fields);
561                 let code = deserialize_fields(&fields);
562
563                 quote! {
564                     #code
565                     Ok(Self #fields_struct)
566                 }
567             }
568             _ => {
569                 panic!("only enum and struct supported");
570             }
571         }
572     });
573
574     quote! {
575                 #[automatically_derived]
576                 impl mt_ser::MtDeserialize for #typename {
577                         #[allow(non_upper_case_globals)]
578                         fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
579                                 #code
580                         }
581                 }
582         }.into()
583 }