]> git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
82d964421b8628e624ebc18d2ed49ce2b86c8929
[mt_ser.git] / derive / src / lib.rs
1 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
2 use proc_macro::TokenStream;
3 use proc_macro2::TokenStream as TokStr;
4 use quote::{quote, ToTokens};
5 use syn::{parse_macro_input, parse_quote};
6
7 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
8 #[darling(rename_all = "snake_case")]
9 enum To {
10     Clt,
11     Srv,
12 }
13
14 #[derive(Debug, FromMeta)]
15 struct MacroArgs {
16     to: To,
17     repr: Option<syn::Type>,
18     tag: Option<String>,
19     content: Option<String>,
20     #[darling(default)]
21     custom: bool,
22     #[darling(default)]
23     enumset: bool,
24 }
25
26 fn wrap_attr(attr: &mut syn::Attribute) {
27     let path = attr.path.clone();
28     let tokens = attr.tokens.clone();
29
30     match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
31         Some("mt") => {
32             *attr = parse_quote! {
33                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
34             };
35         }
36         Some("serde") => {
37             *attr = parse_quote! {
38                 #[cfg_attr(feature = "serde", #path #tokens)]
39             };
40         }
41         _ => {}
42     }
43 }
44
45 #[proc_macro_attribute]
46 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
47     let item2 = item.clone();
48
49     let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
50     let mut input = parse_macro_input!(item2 as syn::Item);
51
52     let args = match MacroArgs::from_list(&attr_args) {
53         Ok(v) => v,
54         Err(e) => {
55             return TokenStream::from(e.write_errors());
56         }
57     };
58
59     let (serializer, deserializer) = match args.to {
60         To::Clt => ("server", "client"),
61         To::Srv => ("client", "server"),
62     };
63
64     let mut out = quote! {
65         #[derive(Debug)]
66         #[cfg_attr(feature = "random", derive(GenerateRandom))]
67         #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
68     };
69
70     macro_rules! iter {
71         ($t:expr, $f:expr) => {
72             $t.iter_mut().for_each($f)
73         };
74     }
75
76     match &mut input {
77         syn::Item::Enum(e) => {
78             iter!(e.attrs, wrap_attr);
79             iter!(e.variants, |v| {
80                 iter!(v.attrs, wrap_attr);
81                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
82             });
83
84             if args.enumset {
85                 let repr_str = args
86                     .repr
87                     .expect("missing repr for enum")
88                     .to_token_stream()
89                     .to_string();
90
91                 out.extend(quote! {
92                     #[derive(EnumSetType)]
93                     #[enumset(repr = #repr_str, serialize_as_map)]
94                 })
95             } else {
96                 let has_payload = e
97                     .variants
98                     .iter()
99                     .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
100                     .is_some();
101
102                 if has_payload {
103                     let tag = args.tag.expect("missing tag for enum with payload");
104
105                     out.extend(quote! {
106                         #[cfg_attr(feature = "serde", serde(tag = #tag))]
107                     });
108
109                     if let Some(content) = args.content {
110                         out.extend(quote! {
111                             #[cfg_attr(feature = "serde", serde(content = #content))]
112                         });
113                     }
114                 } else {
115                     out.extend(quote! {
116                         #[derive(Copy, Eq)]
117                     });
118                 }
119
120                 if let Some(repr) = args.repr {
121                     out.extend(quote! {
122                         #[repr(#repr)]
123                     });
124                 } else if !args.custom {
125                     panic!("missing repr for enum");
126                 }
127
128                 out.extend(quote! {
129                     #[derive(Clone, PartialEq)]
130                 });
131
132                 if !args.custom {
133                     out.extend(quote! {
134                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
135                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
136                     });
137                 }
138             }
139
140             out.extend(quote! {
141                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
142             });
143         }
144         syn::Item::Struct(s) => {
145             iter!(s.attrs, wrap_attr);
146             iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
147
148             out.extend(quote! {
149                 #[derive(Clone, PartialEq)]
150             });
151
152             if !args.custom {
153                 out.extend(quote! {
154                     #[cfg_attr(feature = #serializer, derive(MtSerialize))]
155                     #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
156                 });
157             }
158         }
159         _ => panic!("only enum and struct supported"),
160     }
161
162     out.extend(input.to_token_stream());
163     out.into()
164 }
165
166 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
167 #[darling(attributes(mt))]
168 #[darling(default)]
169 struct MtArgs {
170     #[darling(multiple)]
171     const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
172
173     #[darling(multiple)]
174     const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
175
176     size: Option<syn::Type>, // must implement MtCfg
177
178     len: Option<syn::Type>, // must implement MtCfg
179
180     zlib: bool,
181     zstd: bool,    // TODO
182     default: bool, // type must implement Default
183 }
184
185 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
186
187 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
188     match fields {
189         syn::Fields::Named(fs) => fs
190             .named
191             .iter()
192             .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
193             .collect(),
194         syn::Fields::Unnamed(fs) => fs
195             .unnamed
196             .iter()
197             .enumerate()
198             .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
199             .collect(),
200         syn::Fields::Unit => Vec::new(),
201     }
202 }
203
204 fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
205     match res {
206         Ok(args) => {
207             let mut code = body(&args);
208
209             if args.zlib {
210                 code = quote! {
211                     let mut __writer = {
212                         let mut __stream = mt_ser::flate2::write::ZlibEncoder::new(
213                             __writer,
214                             mt_ser::flate2::Compression::default(),
215                         );
216                         let __writer = &mut __stream;
217                         #code
218                         __stream.finish()?
219                     };
220                 };
221             }
222
223             if let Some(size) = args.size {
224                 code = quote! {
225                     mt_ser::MtSerialize::mt_serialize::<#size>(&{
226                         let mut __buf = Vec::new();
227                         let __writer = &mut __buf;
228                         #code
229                         __buf
230                     }, __writer)?;
231                 };
232             }
233
234             for x in args.const_before.iter().rev() {
235                 code = quote! {
236                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
237                     #code
238                 }
239             }
240
241             for x in args.const_after.iter() {
242                 code = quote! {
243                     #code
244                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
245                 }
246             }
247
248             code
249         }
250         Err(e) => return e.write_errors(),
251     }
252 }
253
254 fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
255     match res {
256         Ok(args) => {
257             let mut code = body(&args);
258
259             if args.zlib {
260                 code = quote! {
261                     {
262                         let mut __owned_reader = mt_ser::flate2::read::ZlibDecoder::new(
263                             mt_ser::WrapRead(__reader));
264                         let __reader = &mut __owned_reader;
265
266                         #code
267                     }
268                 }
269             }
270
271             if let Some(size) = args.size {
272                 code = quote! {
273                     #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
274                         let mut __owned_reader = std::io::Read::take(
275                             mt_ser::WrapRead(__reader), size as u64);
276                         let __reader = &mut __owned_reader;
277
278                         #code
279                     })
280                 };
281             }
282
283             let impl_const = |value: TokStr| {
284                 quote! {
285                     {
286                         fn deserialize_same_type<T: MtDeserialize>(
287                             _: &T,
288                             reader: &mut impl std::io::Read
289                         ) -> Result<T, mt_ser::DeserializeError> {
290                             T::mt_deserialize::<mt_ser::DefCfg>(reader)
291                         }
292
293                         deserialize_same_type(&want, __reader)
294                             .and_then(|got| {
295                                 if want == got {
296                                     #value
297                                 } else {
298                                     Err(mt_ser::DeserializeError::InvalidConst(
299                                         Box::new(want), Box::new(got)
300                                     ))
301                                 }
302                             })
303                     }
304                 }
305             };
306
307             for want in args.const_before.iter().rev() {
308                 let imp = impl_const(code);
309                 code = quote! {
310                     {
311                         let want = #want;
312                         #imp
313                     }
314                 };
315             }
316
317             for want in args.const_after.iter() {
318                 let imp = impl_const(quote! { Ok(value) });
319                 code = quote! {
320                     {
321                         let want = #want;
322                         #code.and_then(|value| { #imp })
323                     }
324                 };
325             }
326
327             code
328         }
329         Err(e) => return e.write_errors(),
330     }
331 }
332
333 fn serialize_fields(fields: &Fields) -> TokStr {
334     fields
335         .iter()
336         .map(|(ident, field)| {
337             serialize_args(MtArgs::from_field(field), |args| {
338                 let def = parse_quote! { mt_ser::DefCfg };
339                 let len = args.len.as_ref().unwrap_or(&def);
340                 quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }
341             })
342         })
343         .collect()
344 }
345
346 fn deserialize_fields(fields: &Fields) -> TokStr {
347     fields
348         .iter()
349         .map(|(ident, field)| {
350             let code = deserialize_args(MtArgs::from_field(field), |args| {
351                 let def = parse_quote! { mt_ser::DefCfg };
352                 let len = args.len.as_ref().unwrap_or(&def);
353                 let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
354
355                 if args.default {
356                     code = quote! {
357                         mt_ser::OrDefault::or_default(#code)
358                     };
359                 }
360
361                 code
362             });
363
364             quote! {
365                 let #ident = #code?;
366             }
367         })
368         .collect()
369 }
370
371 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
372     let ident_fn = match input {
373         syn::Fields::Unnamed(_) => |f| {
374             quote! {
375                 mt_ser::paste::paste! { [<field_ #f>] }
376             }
377         },
378         _ => |f| quote! { #f },
379     };
380
381     let fields = get_fields(input, ident_fn);
382     let fields_comma: TokStr = fields
383         .iter()
384         .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
385
386     let fields_struct = match input {
387         syn::Fields::Named(_) => quote! { { #fields_comma } },
388         syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
389         syn::Fields::Unit => TokStr::new(),
390     };
391
392     (fields, fields_struct)
393 }
394
395 fn get_repr(input: &syn::DeriveInput) -> syn::Type {
396     input
397         .attrs
398         .iter()
399         .find(|a| a.path.is_ident("repr"))
400         .expect("missing repr")
401         .parse_args()
402         .expect("invalid repr")
403 }
404
405 #[proc_macro_derive(MtSerialize, attributes(mt))]
406 pub fn derive_serialize(input: TokenStream) -> TokenStream {
407     let input = parse_macro_input!(input as syn::DeriveInput);
408     let typename = &input.ident;
409
410     let code = serialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
411         syn::Data::Enum(e) => {
412             let repr = get_repr(&input);
413             let variants: TokStr = e.variants
414                                 .iter()
415                                 .fold((parse_quote! { 0 }, TokStr::new()), |(discr, before), v| {
416                                         let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
417                                         let (fields, fields_struct) = get_fields_struct(&v.fields);
418
419                                         let code = serialize_args(MtArgs::from_variant(v), |_|
420                                                 serialize_fields(&fields));
421                                         let variant = &v.ident;
422
423                                         (
424                                                 parse_quote! { 1 + #discr },
425                                                 quote! {
426                                                         #before
427                                                         #typename::#variant #fields_struct => {
428                                                                 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
429                                                                 #code
430                                                         }
431                                                 }
432                                         )
433                                 }).1;
434
435             quote! {
436                 match self {
437                     #variants
438                 }
439             }
440         }
441         syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
442         _ => {
443             panic!("only enum and struct supported");
444         }
445     });
446
447     quote! {
448                 #[automatically_derived]
449                 impl mt_ser::MtSerialize for #typename {
450                         fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
451                                 #code
452
453                                 Ok(())
454                         }
455                 }
456         }.into()
457 }
458
459 #[proc_macro_derive(MtDeserialize, attributes(mt))]
460 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
461     let input = parse_macro_input!(input as syn::DeriveInput);
462     let typename = &input.ident;
463
464     let code = deserialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
465         syn::Data::Enum(e) => {
466             let repr = get_repr(&input);
467             let type_str = typename.to_string();
468
469             let mut consts = TokStr::new();
470             let mut arms = TokStr::new();
471             let mut discr = parse_quote! { 0 };
472
473             for v in e.variants.iter() {
474                 discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
475
476                 let ident = &v.ident;
477                 let (fields, fields_struct) = get_fields_struct(&v.fields);
478                 let code = deserialize_args(MtArgs::from_variant(v), |_| {
479                     let fields_code = deserialize_fields(&fields);
480
481                     quote! {
482                         #fields_code
483                         Ok(Self::#ident #fields_struct)
484                     }
485                 });
486
487                 consts.extend(quote! {
488                     const #ident: #repr = #discr;
489                 });
490
491                 arms.extend(quote! {
492                     #ident => { #code }
493                 });
494
495                 discr = parse_quote! { 1 + #discr };
496             }
497
498             quote! {
499                 #consts
500
501                 match mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)? {
502                     #arms
503                     x => Err(mt_ser::DeserializeError::InvalidEnumVariant(#type_str, x as u64))
504                 }
505             }
506         }
507         syn::Data::Struct(s) => {
508             let (fields, fields_struct) = get_fields_struct(&s.fields);
509             let code = deserialize_fields(&fields);
510
511             quote! {
512                 #code
513                 Ok(Self #fields_struct)
514             }
515         }
516         _ => {
517             panic!("only enum and struct supported");
518         }
519     });
520
521     quote! {
522                 #[automatically_derived]
523                 impl mt_ser::MtDeserialize for #typename {
524                         #[allow(non_upper_case_globals)]
525                         fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
526                                 #code
527                         }
528                 }
529         }.into()
530 }