]> git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
cargo fmt
[mt_ser.git] / derive / src / lib.rs
1 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
2 use proc_macro::TokenStream;
3 use proc_macro2::TokenStream as TokStr;
4 use quote::{quote, ToTokens};
5 use syn::{parse_macro_input, parse_quote};
6
7 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
8 #[darling(rename_all = "snake_case")]
9 enum To {
10     Clt,
11     Srv,
12 }
13
14 #[derive(Debug, FromMeta)]
15 struct MacroArgs {
16     to: To,
17     repr: Option<syn::Type>,
18     tag: Option<String>,
19     content: Option<String>,
20     #[darling(default)]
21     custom: bool,
22     #[darling(default)]
23     enumset: bool,
24 }
25
26 fn wrap_attr(attr: &mut syn::Attribute) {
27     let path = attr.path.clone();
28     let tokens = attr.tokens.clone();
29
30     match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
31         Some("mt") => {
32             *attr = parse_quote! {
33                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
34             };
35         }
36         Some("serde") => {
37             *attr = parse_quote! {
38                 #[cfg_attr(feature = "serde", #path #tokens)]
39             };
40         }
41         _ => {}
42     }
43 }
44
45 #[proc_macro_attribute]
46 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
47     let item2 = item.clone();
48
49     let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
50     let mut input = parse_macro_input!(item2 as syn::Item);
51
52     let args = match MacroArgs::from_list(&attr_args) {
53         Ok(v) => v,
54         Err(e) => {
55             return TokenStream::from(e.write_errors());
56         }
57     };
58
59     let (serializer, deserializer) = match args.to {
60         To::Clt => ("server", "client"),
61         To::Srv => ("client", "server"),
62     };
63
64     let mut out = quote! {
65         #[derive(Debug)]
66         #[cfg_attr(feature = "random", derive(GenerateRandom))]
67         #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
68     };
69
70     macro_rules! iter {
71         ($t:expr, $f:expr) => {
72             $t.iter_mut().for_each($f)
73         };
74     }
75
76     match &mut input {
77         syn::Item::Enum(e) => {
78             iter!(e.attrs, wrap_attr);
79             iter!(e.variants, |v| {
80                 iter!(v.attrs, wrap_attr);
81                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
82             });
83
84             let repr = args.repr.expect("missing repr for enum");
85
86             if args.enumset {
87                 let repr_str = repr.to_token_stream().to_string();
88
89                 out.extend(quote! {
90                     #[derive(EnumSetType)]
91                     #[enumset(repr = #repr_str, serialize_as_map)]
92                 })
93             } else {
94                 let has_payload = e
95                     .variants
96                     .iter()
97                     .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
98                     .is_some();
99
100                 if has_payload {
101                     let tag = args.tag.expect("missing tag for enum with payload");
102
103                     out.extend(quote! {
104                         #[cfg_attr(feature = "serde", serde(tag = #tag))]
105                     });
106
107                     if let Some(content) = args.content {
108                         out.extend(quote! {
109                             #[cfg_attr(feature = "serde", serde(content = #content))]
110                         });
111                     }
112                 } else {
113                     out.extend(quote! {
114                         #[derive(Copy, Eq)]
115                     });
116                 }
117
118                 out.extend(quote! {
119                     #[repr(#repr)]
120                     #[derive(Clone, PartialEq)]
121                 });
122
123                 if !args.custom {
124                     out.extend(quote! {
125                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
126                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
127                     });
128                 }
129             }
130
131             out.extend(quote! {
132                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
133             });
134         }
135         syn::Item::Struct(s) => {
136             iter!(s.attrs, wrap_attr);
137             iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
138
139             out.extend(quote! {
140                 #[derive(Clone, PartialEq)]
141             });
142
143             if !args.custom {
144                 out.extend(quote! {
145                     #[cfg_attr(feature = #serializer, derive(MtSerialize))]
146                     #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
147                 });
148             }
149         }
150         _ => panic!("only enum and struct supported"),
151     }
152
153     out.extend(input.to_token_stream());
154     out.into()
155 }
156
157 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
158 #[darling(attributes(mt))]
159 #[darling(default)]
160 struct MtArgs {
161     #[darling(multiple)]
162     const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
163
164     #[darling(multiple)]
165     const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
166
167     size: Option<syn::Type>, // must implement MtCfg
168
169     len: Option<syn::Type>, // must implement MtCfg
170
171     zlib: bool,
172     zstd: bool,    // TODO
173     default: bool, // type must implement Default
174 }
175
176 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
177
178 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
179     match fields {
180         syn::Fields::Named(fs) => fs
181             .named
182             .iter()
183             .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
184             .collect(),
185         syn::Fields::Unnamed(fs) => fs
186             .unnamed
187             .iter()
188             .enumerate()
189             .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
190             .collect(),
191         syn::Fields::Unit => Vec::new(),
192     }
193 }
194
195 fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
196     match res {
197         Ok(args) => {
198             let mut code = body(&args);
199
200             if args.zlib {
201                 code = quote! {
202                     let mut __writer = {
203                         let mut __stream = mt_ser::flate2::write::ZlibEncoder::new(
204                             __writer,
205                             mt_ser::flate2::Compression::default(),
206                         );
207                         let __writer = &mut __stream;
208                         #code
209                         __stream.finish()?
210                     };
211                 };
212             }
213
214             if let Some(size) = args.size {
215                 code = quote! {
216                     mt_ser::MtSerialize::mt_serialize::<#size>(&{
217                         let mut __buf = Vec::new();
218                         let __writer = &mut __buf;
219                         #code
220                         __buf
221                     }, __writer)?;
222                 };
223             }
224
225             for x in args.const_before.iter().rev() {
226                 code = quote! {
227                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
228                     #code
229                 }
230             }
231
232             for x in args.const_after.iter() {
233                 code = quote! {
234                     #code
235                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
236                 }
237             }
238
239             code
240         }
241         Err(e) => return e.write_errors(),
242     }
243 }
244
245 fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
246     match res {
247         Ok(args) => {
248             let mut code = body(&args);
249
250             if args.zlib {
251                 code = quote! {
252                     {
253                         let mut __owned_reader = mt_ser::flate2::read::ZlibDecoder::new(
254                             mt_ser::WrapRead(__reader));
255                         let __reader = &mut __owned_reader;
256
257                         #code
258                     }
259                 }
260             }
261
262             if let Some(size) = args.size {
263                 code = quote! {
264                     #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
265                         let mut __owned_reader = std::io::Read::take(
266                             mt_ser::WrapRead(__reader), size as u64);
267                         let __reader = &mut __owned_reader;
268
269                         #code
270                     })
271                 };
272             }
273
274             let impl_const = |value: TokStr| {
275                 quote! {
276                     {
277                         fn deserialize_same_type<T: MtDeserialize>(
278                             _: &T,
279                             reader: &mut impl std::io::Read
280                         ) -> Result<T, DeserializeError> {
281                             T::mt_deserialize::<mt_ser::DefCfg>(reader)
282                         }
283
284                         deserialize_same_type(&want, __reader)
285                             .and_then(|got| {
286                                 if want == got {
287                                     #value
288                                 } else {
289                                     Err(mt_ser::DeserializeError::InvalidConst(
290                                         Box::new(want), Box::new(got)
291                                     ))
292                                 }
293                             })
294                     }
295                 }
296             };
297
298             for want in args.const_before.iter().rev() {
299                 let imp = impl_const(code);
300                 code = quote! {
301                     {
302                         let want = #want;
303                         #imp
304                     }
305                 };
306             }
307
308             for want in args.const_after.iter() {
309                 let imp = impl_const(quote! { Ok(value) });
310                 code = quote! {
311                     {
312                         let want = #want;
313                         #code.and_then(|value| { #imp })
314                     }
315                 };
316             }
317
318             code
319         }
320         Err(e) => return e.write_errors(),
321     }
322 }
323
324 fn serialize_fields(fields: &Fields) -> TokStr {
325     fields
326         .iter()
327         .map(|(ident, field)| {
328             serialize_args(MtArgs::from_field(field), |args| {
329                 let def = parse_quote! { mt_ser::DefCfg };
330                 let len = args.len.as_ref().unwrap_or(&def);
331                 quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }
332             })
333         })
334         .collect()
335 }
336
337 fn deserialize_fields(fields: &Fields) -> TokStr {
338     fields
339         .iter()
340         .map(|(ident, field)| {
341             let code = deserialize_args(MtArgs::from_field(field), |args| {
342                 let def = parse_quote! { mt_ser::DefCfg };
343                 let len = args.len.as_ref().unwrap_or(&def);
344                 let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
345
346                 if args.default {
347                     code = quote! {
348                         mt_ser::OrDefault::or_default(#code)
349                     };
350                 }
351
352                 code
353             });
354
355             quote! {
356                 let #ident = #code?;
357             }
358         })
359         .collect()
360 }
361
362 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
363     let ident_fn = match input {
364         syn::Fields::Unnamed(_) => |f| {
365             quote! {
366                 mt_ser::paste::paste! { [<field_ #f>] }
367             }
368         },
369         _ => |f| quote! { #f },
370     };
371
372     let fields = get_fields(input, ident_fn);
373     let fields_comma: TokStr = fields
374         .iter()
375         .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
376
377     let fields_struct = match input {
378         syn::Fields::Named(_) => quote! { { #fields_comma } },
379         syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
380         syn::Fields::Unit => TokStr::new(),
381     };
382
383     (fields, fields_struct)
384 }
385
386 fn get_repr(input: &syn::DeriveInput) -> syn::Type {
387     input
388         .attrs
389         .iter()
390         .find(|a| a.path.is_ident("repr"))
391         .expect("missing repr")
392         .parse_args()
393         .expect("invalid repr")
394 }
395
396 #[proc_macro_derive(MtSerialize, attributes(mt))]
397 pub fn derive_serialize(input: TokenStream) -> TokenStream {
398     let input = parse_macro_input!(input as syn::DeriveInput);
399     let typename = &input.ident;
400
401     let code = serialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
402         syn::Data::Enum(e) => {
403             let repr = get_repr(&input);
404             let variants: TokStr = e.variants
405                                 .iter()
406                                 .fold((parse_quote! { 0 }, TokStr::new()), |(discr, before), v| {
407                                         let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
408                                         let (fields, fields_struct) = get_fields_struct(&v.fields);
409
410                                         let code = serialize_args(MtArgs::from_variant(v), |_|
411                                                 serialize_fields(&fields));
412                                         let variant = &v.ident;
413
414                                         (
415                                                 parse_quote! { 1 + #discr },
416                                                 quote! {
417                                                         #before
418                                                         #typename::#variant #fields_struct => {
419                                                                 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
420                                                                 #code
421                                                         }
422                                                 }
423                                         )
424                                 }).1;
425
426             quote! {
427                 match self {
428                     #variants
429                 }
430             }
431         }
432         syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
433         _ => {
434             panic!("only enum and struct supported");
435         }
436     });
437
438     quote! {
439                 #[automatically_derived]
440                 impl mt_ser::MtSerialize for #typename {
441                         fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
442                                 #code
443
444                                 Ok(())
445                         }
446                 }
447         }.into()
448 }
449
450 #[proc_macro_derive(MtDeserialize, attributes(mt))]
451 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
452     let input = parse_macro_input!(input as syn::DeriveInput);
453     let typename = &input.ident;
454
455     let code = deserialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
456         syn::Data::Enum(e) => {
457             let repr = get_repr(&input);
458             let type_str = typename.to_string();
459
460             let mut consts = TokStr::new();
461             let mut arms = TokStr::new();
462             let mut discr = parse_quote! { 0 };
463
464             for v in e.variants.iter() {
465                 discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
466
467                 let ident = &v.ident;
468                 let (fields, fields_struct) = get_fields_struct(&v.fields);
469                 let code = deserialize_args(MtArgs::from_variant(v), |_| {
470                     let fields_code = deserialize_fields(&fields);
471
472                     quote! {
473                         #fields_code
474                         Ok(Self::#ident #fields_struct)
475                     }
476                 });
477
478                 consts.extend(quote! {
479                     const #ident: #repr = #discr;
480                 });
481
482                 arms.extend(quote! {
483                     #ident => { #code }
484                 });
485
486                 discr = parse_quote! { 1 + #discr };
487             }
488
489             quote! {
490                 #consts
491
492                 match mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)? {
493                     #arms
494                     x => Err(mt_ser::DeserializeError::InvalidEnumVariant(#type_str, x as u64))
495                 }
496             }
497         }
498         syn::Data::Struct(s) => {
499             let (fields, fields_struct) = get_fields_struct(&s.fields);
500             let code = deserialize_fields(&fields);
501
502             quote! {
503                 #code
504                 Ok(Self #fields_struct)
505             }
506         }
507         _ => {
508             panic!("only enum and struct supported");
509         }
510     });
511
512     quote! {
513                 #[automatically_derived]
514                 impl mt_ser::MtDeserialize for #typename {
515                         #[allow(non_upper_case_globals)]
516                         fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
517                                 #code
518                         }
519                 }
520         }.into()
521 }