]> git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
0ccc8cacb9b14ff4b97d853ccac577b672b31700
[mt_ser.git] / derive / src / lib.rs
1 use convert_case::{Case, Casing};
2 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
3 use proc_macro::TokenStream;
4 use proc_macro2::TokenStream as TokStr;
5 use quote::{quote, ToTokens};
6 use syn::{parse_macro_input, parse_quote};
7
/// Which side of the connection a type generated by `mt_derive` is sent to.
///
/// Spelled in snake_case in the attribute (per `rename_all`). `mt_derive` maps
/// `Clt` to "server serializes / client deserializes" and `Srv` to the reverse,
/// selecting the `MtSerialize`/`MtDeserialize` derives by cargo feature.
#[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
#[darling(rename_all = "snake_case")]
enum To {
    /// Sent to the client: derived `MtSerialize` under `server`, `MtDeserialize` under `client`.
    Clt,
    /// Sent to the server: derived `MtSerialize` under `client`, `MtDeserialize` under `server`.
    Srv,
}
14
/// Arguments accepted by the `#[mt_derive(...)]` attribute macro.
#[derive(Debug, FromMeta)]
struct MacroArgs {
    /// Direction the type travels; decides which feature derives `MtSerialize` vs `MtDeserialize`.
    to: To,
    /// Enum representation. Emitted as `#[repr(..)]` (or `#[enumset(repr = ..)]` for
    /// enumsets); for non-enumset enums, `str` selects `mt(string_repr)` instead.
    /// Required for enums unless `custom` is set.
    repr: Option<syn::Type>,
    /// Serde tag field name; required for (non-enumset) enums with a payload-carrying variant.
    tag: Option<String>,
    /// Optional serde content field name, only used together with `tag`.
    content: Option<String>,
    /// Do not derive `MtSerialize`/`MtDeserialize` and do not require `repr`
    /// (presumably because the impls are written by hand).
    #[darling(default)]
    custom: bool,
    /// Derive `EnumSetType` instead of the regular enum derives.
    #[darling(default)]
    enumset: bool,
}
26
27 fn wrap_attr(attr: &mut syn::Attribute) {
28     let path = attr.path.clone();
29     let tokens = attr.tokens.clone();
30
31     match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
32         Some("mt") => {
33             *attr = parse_quote! {
34                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
35             };
36         }
37         Some("serde") => {
38             *attr = parse_quote! {
39                 #[cfg_attr(feature = "serde", #path #tokens)]
40             };
41         }
42         _ => {}
43     }
44 }
45
46 #[proc_macro_attribute]
47 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
48     let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
49     let mut input = parse_macro_input!(item as syn::Item);
50
51     let args = match MacroArgs::from_list(&attr_args) {
52         Ok(v) => v,
53         Err(e) => {
54             return TokenStream::from(e.write_errors());
55         }
56     };
57
58     let (serializer, deserializer) = match args.to {
59         To::Clt => ("server", "client"),
60         To::Srv => ("client", "server"),
61     };
62
63     let mut out = quote! {
64         #[derive(Debug)]
65         #[cfg_attr(feature = "random", derive(GenerateRandom))]
66         #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
67     };
68
69     macro_rules! iter {
70         ($t:expr, $f:expr) => {
71             $t.iter_mut().for_each($f)
72         };
73     }
74
75     match &mut input {
76         syn::Item::Enum(e) => {
77             iter!(e.attrs, wrap_attr);
78             iter!(e.variants, |v| {
79                 iter!(v.attrs, wrap_attr);
80                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
81             });
82
83             if args.enumset {
84                 out.extend(quote! {
85                     #[derive(EnumSetType)]
86                     #[enumset(serialize_as_map)]
87                 });
88
89                 if let Some(repr) = args.repr {
90                     let repr_str = repr.to_token_stream().to_string();
91
92                     out.extend(quote! {
93                         #[enumset(repr = #repr_str)]
94                     });
95                 } else if !args.custom {
96                     panic!("missing repr for enum");
97                 }
98             } else {
99                 let has_payload = e
100                     .variants
101                     .iter()
102                     .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
103                     .is_some();
104
105                 if has_payload {
106                     let tag = args.tag.expect("missing tag for enum with payload");
107
108                     out.extend(quote! {
109                         #[cfg_attr(feature = "serde", serde(tag = #tag))]
110                     });
111
112                     if let Some(content) = args.content {
113                         out.extend(quote! {
114                             #[cfg_attr(feature = "serde", serde(content = #content))]
115                         });
116                     }
117                 } else {
118                     out.extend(quote! {
119                         #[derive(Copy, Eq)]
120                     });
121                 }
122
123                 out.extend(quote! {
124                     #[derive(Clone, PartialEq)]
125                 });
126
127                 if !args.custom {
128                     out.extend(quote! {
129                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
130                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
131                     });
132                 }
133
134                 if let Some(repr) = args.repr {
135                     if repr == parse_quote! { str } {
136                         out.extend(quote! {
137                                                         #[cfg_attr(any(feature = "client", feature = "server"), mt(string_repr))]
138                                                 });
139                     } else {
140                         out.extend(quote! {
141                             #[repr(#repr)]
142                         });
143                     }
144                 } else if !args.custom {
145                     panic!("missing repr for enum");
146                 }
147             }
148
149             out.extend(quote! {
150                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
151             });
152         }
153         syn::Item::Struct(s) => {
154             iter!(s.attrs, wrap_attr);
155             iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
156
157             out.extend(quote! {
158                 #[derive(Clone, PartialEq)]
159             });
160
161             if !args.custom {
162                 out.extend(quote! {
163                     #[cfg_attr(feature = #serializer, derive(MtSerialize))]
164                     #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
165                 });
166             }
167         }
168         _ => panic!("only enum and struct supported"),
169     }
170
171     out.extend(input.to_token_stream());
172     out.into()
173 }
174
/// Options read from `#[mt(...)]` attributes on a container (`FromDeriveInput`),
/// an enum variant (`FromVariant`) or a field (`FromField`).
///
/// Every option is optional (`#[darling(default)]`); which ones take effect
/// depends on where the attribute appears.
#[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
#[darling(attributes(mt))]
#[darling(default)]
struct MtArgs {
    /// Values written before the body and verified when reading; may be repeated.
    #[darling(multiple)]
    const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
    /// Values written after the body and verified when reading; may be repeated.
    #[darling(multiple)]
    const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
    /// When set, the body is serialized into a buffer that is then written with
    /// this config; on read, a value of this type is read and bounds how many
    /// bytes the body may consume.
    size: Option<syn::Type>, // must implement MtCfg
    /// Config passed to a field's own (de)serialization instead of `mt_ser::DefCfg`.
    len: Option<syn::Type>,  // must implement MtCfg
    /// On read, wrap the field's deserialization in `mt_ser::OrDefault::or_default`.
    default: bool,           // type must implement Default
    /// Enum discriminants are the snake_case variant names (strings) instead of integers.
    string_repr: bool,       // for enums
    /// Compress the body with zlib.
    zlib: bool,
    /// Compress the body with zstd.
    zstd: bool,
    /// Type name used in the generated impl instead of the input's own ident.
    typename: Option<syn::Ident>, // remote derive
    /// Where clause that replaces the automatically generated trait bounds.
    bounds: Option<syn::WhereClause>,
}
192
/// Fields paired with the tokens used to refer to each one in generated code
/// (e.g. `&self.name` for structs, or a binding name inside a pattern).
type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
194
195 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
196     match fields {
197         syn::Fields::Named(fs) => fs
198             .named
199             .iter()
200             .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
201             .collect(),
202         syn::Fields::Unnamed(fs) => fs
203             .unnamed
204             .iter()
205             .enumerate()
206             .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
207             .collect(),
208         syn::Fields::Unit => Vec::new(),
209     }
210 }
211
212 fn serialize_args(args: &MtArgs, code: &mut TokStr) {
213     macro_rules! impl_compress {
214         ($create:expr) => {
215             *code = quote! {
216                 let mut __writer = {
217                     let mut __stream = $create;
218                     let __writer = &mut __stream;
219                     #code
220                     __stream.finish()?
221                 };
222             };
223         };
224     }
225
226     if args.zlib {
227         impl_compress!(mt_ser::flate2::write::ZlibEncoder::new(
228             __writer,
229             mt_ser::flate2::Compression::default()
230         ));
231     }
232
233     if args.zstd {
234         impl_compress!(mt_ser::zstd::stream::write::Encoder::new(__writer, 0)?);
235     }
236
237     if let Some(size) = &args.size {
238         *code = quote! {
239             mt_ser::MtSerialize::mt_serialize::<#size>(&{
240                 let mut __buf = Vec::new();
241                 let __writer = &mut __buf;
242                 #code
243                 __buf
244             }, __writer)?;
245         };
246     }
247
248     for x in args.const_before.iter().rev() {
249         *code = quote! {
250             #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
251             #code
252         }
253     }
254
255     for x in args.const_after.iter() {
256         *code = quote! {
257             #code
258             #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
259         }
260     }
261 }
262
263 fn deserialize_args(args: &MtArgs, code: &mut TokStr) {
264     macro_rules! impl_compress {
265         ($create:expr) => {
266             *code = quote! {
267                 {
268                     let mut __owned_reader = $create;
269                     let __reader = &mut __owned_reader;
270
271                     #code
272                 }
273             }
274         };
275     }
276
277     if args.zlib {
278         impl_compress!(mt_ser::flate2::read::ZlibDecoder::new(mt_ser::WrapRead(
279             __reader
280         )));
281     }
282
283     if args.zstd {
284         impl_compress!(mt_ser::zstd::stream::read::Decoder::new(mt_ser::WrapRead(
285             __reader
286         ))?);
287     }
288
289     if let Some(size) = &args.size {
290         *code = quote! {
291             #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
292                 let mut __owned_reader = std::io::Read::take(
293                     mt_ser::WrapRead(__reader), size as u64);
294                 let __reader = &mut __owned_reader;
295
296                 #code
297             })
298         };
299     }
300
301     let impl_const = |value: &TokStr| {
302         quote! {
303             {
304                 fn deserialize_same_type<T: MtDeserialize>(
305                     _: &T,
306                     reader: &mut impl std::io::Read
307                 ) -> Result<T, mt_ser::DeserializeError> {
308                     T::mt_deserialize::<mt_ser::DefCfg>(reader)
309                 }
310
311                 deserialize_same_type(&want, __reader)
312                     .and_then(|got| {
313                         if want == got {
314                             #value
315                         } else {
316                             Err(mt_ser::DeserializeError::InvalidConst(
317                                 Box::new(want), Box::new(got)
318                             ))
319                         }
320                     })
321             }
322         }
323     };
324
325     for want in args.const_before.iter().rev() {
326         let imp = impl_const(&code);
327         *code = quote! {
328             {
329                 let want = #want;
330                 #imp
331             }
332         };
333     }
334
335     for want in args.const_after.iter() {
336         let imp = impl_const(&quote! { Ok(value) });
337         *code = quote! {
338             {
339                 let want = #want;
340                 #code.and_then(|value| { #imp })
341             }
342         };
343     }
344 }
345
346 fn serialize_fields(fields: &Fields) -> TokStr {
347     fields
348         .iter()
349         .map(|(ident, field)| {
350             let args = MtArgs::from_field(field).unwrap();
351             let def = parse_quote! { mt_ser::DefCfg };
352             let len = args.len.as_ref().unwrap_or(&def);
353
354             let mut code = quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; };
355             serialize_args(&args, &mut code);
356
357             code
358         })
359         .collect()
360 }
361
362 fn deserialize_fields(fields: &Fields) -> TokStr {
363     fields
364         .iter()
365         .map(|(ident, field)| {
366             let args = MtArgs::from_field(field).unwrap();
367
368             let def = parse_quote! { mt_ser::DefCfg };
369             let len = args.len.as_ref().unwrap_or(&def);
370             let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
371
372             if args.default {
373                 code = quote! {
374                     mt_ser::OrDefault::or_default(#code)
375                 };
376             }
377
378             deserialize_args(&args, &mut code);
379
380             quote! {
381                 let #ident = #code?;
382             }
383         })
384         .collect()
385 }
386
387 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
388     let ident_fn = match input {
389         syn::Fields::Unnamed(_) => |f| {
390             quote! {
391                 mt_ser::paste::paste! { [<field_ #f>] }
392             }
393         },
394         _ => |f| quote! { #f },
395     };
396
397     let fields = get_fields(input, ident_fn);
398     let fields_comma: TokStr = fields
399         .iter()
400         .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
401
402     let fields_struct = match input {
403         syn::Fields::Named(_) => quote! { { #fields_comma } },
404         syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
405         syn::Fields::Unit => TokStr::new(),
406     };
407
408     (fields, fields_struct)
409 }
410
411 fn get_repr(input: &syn::DeriveInput, args: &MtArgs) -> syn::Type {
412     if args.string_repr {
413         parse_quote! { &str }
414     } else {
415         input
416             .attrs
417             .iter()
418             .find(|a| a.path.is_ident("repr"))
419             .expect("missing repr")
420             .parse_args()
421             .expect("invalid repr")
422     }
423 }
424
425 fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Variant, &syn::Expr)) {
426     let mut discr = parse_quote! { 0 };
427
428     for v in e.variants.iter() {
429         discr = if args.string_repr {
430             let lit = v.ident.to_string().to_case(Case::Snake);
431             parse_quote! { #lit }
432         } else {
433             v.discriminant.clone().map(|x| x.1).unwrap_or(discr)
434         };
435
436         f(v, &discr);
437
438         discr = parse_quote! { 1 + #discr };
439     }
440 }
441
442 fn make_impl(
443     traitname: TokStr,
444     input: &syn::DeriveInput,
445     typename: &syn::Ident,
446     args: &MtArgs,
447     code: TokStr,
448 ) -> TokenStream {
449     let generics = &input.generics;
450     let bounds = args.bounds.clone().or_else(|| {
451         if generics.params.is_empty() {
452             None
453         } else {
454             Some(
455                 syn::parse(
456                     generics
457                         .params
458                         .iter()
459                         .rfold(quote! { where }, |before, t| match t {
460                             syn::GenericParam::Type(x) => quote! { #before #x: #traitname, },
461                             _ => before,
462                         })
463                         .into(),
464                 )
465                 .expect("invalid where clause"),
466             )
467         }
468     });
469
470     quote! {
471         #[automatically_derived]
472         impl #generics #traitname for #typename #generics #bounds { #code }
473     }
474     .into()
475 }
476
477 #[proc_macro_derive(MtSerialize, attributes(mt))]
478 pub fn derive_serialize(input: TokenStream) -> TokenStream {
479     let input = parse_macro_input!(input as syn::DeriveInput);
480     let args = MtArgs::from_derive_input(&input).unwrap();
481     let typename = args.typename.as_ref().unwrap_or(&input.ident);
482
483     let mut code = match &input.data {
484         syn::Data::Enum(e) => {
485             let repr = get_repr(&input, &args);
486             let mut variants = TokStr::new();
487
488             iter_variants(e, &args, |v, discr| {
489                 let args = MtArgs::from_variant(v).unwrap();
490
491                 let (fields, fields_struct) = get_fields_struct(&v.fields);
492
493                 let mut code = serialize_fields(&fields);
494                 serialize_args(&args, &mut code);
495
496                 let ident = &v.ident;
497
498                 variants.extend(quote! {
499                                         #typename::#ident #fields_struct => {
500                                                 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
501                                                 #code
502                                         }
503                                 });
504             });
505
506             quote! {
507                 match self {
508                     #variants
509                 }
510             }
511         }
512         syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
513         _ => {
514             panic!("only enum and struct supported");
515         }
516     };
517
518     serialize_args(&args, &mut code);
519
520     make_impl(
521         quote! { mt_ser::MtSerialize },
522         &input,
523         typename,
524         &args,
525         quote! {
526             fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
527                 #code
528
529                 Ok(())
530             }
531         },
532     )
533 }
534
535 #[proc_macro_derive(MtDeserialize, attributes(mt))]
536 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
537     let input = parse_macro_input!(input as syn::DeriveInput);
538     let args = MtArgs::from_derive_input(&input).unwrap();
539     let typename = args.typename.as_ref().unwrap_or(&input.ident);
540
541     let mut code = match &input.data {
542         syn::Data::Enum(e) => {
543             let repr = get_repr(&input, &args);
544
545             let mut consts = TokStr::new();
546             let mut arms = TokStr::new();
547
548             iter_variants(e, &args, |v, discr| {
549                 let args = MtArgs::from_variant(v).unwrap();
550
551                 let ident = &v.ident;
552                 let (fields, fields_struct) = get_fields_struct(&v.fields);
553                 let mut code = deserialize_fields(&fields);
554                 code = quote! {
555                     #code
556                     Ok(Self::#ident #fields_struct)
557                 };
558
559                 deserialize_args(&args, &mut code);
560
561                 consts.extend(quote! {
562                     const #ident: #repr = #discr;
563                 });
564
565                 arms.extend(quote! {
566                     #ident => { #code }
567                 });
568             });
569
570             let type_str = typename.to_string();
571             let discr_match = if args.string_repr {
572                 quote! {
573                     let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
574                     match __discr.as_str()
575                 }
576             } else {
577                 quote! {
578                     let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
579                     match __discr
580                 }
581             };
582
583             quote! {
584                 #consts
585
586                 #discr_match {
587                     #arms
588                     _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
589                 }
590             }
591         }
592         syn::Data::Struct(s) => {
593             let (fields, fields_struct) = get_fields_struct(&s.fields);
594             let code = deserialize_fields(&fields);
595
596             quote! {
597                 #code
598                 Ok(Self #fields_struct)
599             }
600         }
601         _ => {
602             panic!("only enum and struct supported");
603         }
604     };
605
606     deserialize_args(&args, &mut code);
607
608     make_impl(
609         quote! { mt_ser::MtDeserialize },
610         &input,
611         typename,
612         &args,
613         quote! {
614             #[allow(non_upper_case_globals)]
615             fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
616                 #code
617             }
618         },
619     )
620 }