]> git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
Improve attributes
[mt_ser.git] / derive / src / lib.rs
1 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
2 use proc_macro::TokenStream;
3 use proc_macro2::TokenStream as TokStr;
4 use quote::{quote, ToTokens};
5 use syn::{parse_macro_input, parse_quote};
6
7 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
8 #[darling(rename_all = "snake_case")]
9 enum To {
10         Clt,
11         Srv,
12 }
13
14 #[derive(Debug, FromMeta)]
15 struct MacroArgs {
16         to: To,
17         repr: Option<syn::Type>,
18         tag: Option<String>,
19         content: Option<String>,
20         #[darling(default)]
21         custom: bool,
22         #[darling(default)]
23         enumset: bool,
24 }
25
26 fn wrap_attr(attr: &mut syn::Attribute) {
27         let path = attr.path.clone();
28         let tokens = attr.tokens.clone();
29
30         match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
31                 Some("mt") => {
32                         *attr = parse_quote! {
33                                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
34                         };
35                 }
36                 Some("serde") => {
37                         *attr = parse_quote! {
38                                 #[cfg_attr(feature = "serde", #path #tokens)]
39                         };
40                 }
41                 _ => {}
42         }
43 }
44
45 #[proc_macro_attribute]
46 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
47         let item2 = item.clone();
48
49         let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
50         let mut input = parse_macro_input!(item2 as syn::Item);
51
52         let args = match MacroArgs::from_list(&attr_args) {
53                 Ok(v) => v,
54                 Err(e) => {
55                         return TokenStream::from(e.write_errors());
56                 }
57         };
58
59         let (serializer, deserializer) = match args.to {
60                 To::Clt => ("server", "client"),
61                 To::Srv => ("client", "server"),
62         };
63
64         let mut out = quote! {
65                 #[derive(Debug)]
66                 #[cfg_attr(feature = "random", derive(GenerateRandom))]
67                 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
68         };
69
70         macro_rules! iter {
71                 ($t:expr, $f:expr) => {
72                         $t.iter_mut().for_each($f)
73                 };
74         }
75
76         match &mut input {
77                 syn::Item::Enum(e) => {
78                         iter!(e.attrs, wrap_attr);
79                         iter!(e.variants, |v| {
80                                 iter!(v.attrs, wrap_attr);
81                                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
82                         });
83
84                         let repr = args.repr.expect("missing repr for enum");
85
86                         if args.enumset {
87                                 let repr_str = repr.to_token_stream().to_string();
88
89                                 out.extend(quote! {
90                                         #[derive(EnumSetType)]
91                                         #[enumset(repr = #repr_str, serialize_as_map)]
92                                 })
93                         } else {
94                                 let has_payload = e
95                                         .variants
96                                         .iter()
97                                         .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
98                                         .is_some();
99
100                                 if has_payload {
101                                         let tag = args.tag.expect("missing tag for enum with payload");
102
103                                         out.extend(quote! {
104                                                 #[cfg_attr(feature = "serde", serde(tag = #tag))]
105                                         });
106
107                                         if let Some(content) = args.content {
108                                                 out.extend(quote! {
109                                                         #[cfg_attr(feature = "serde", serde(content = #content))]
110                                                 });
111                                         }
112                                 } else {
113                                         out.extend(quote! {
114                                                 #[derive(Copy, Eq)]
115                                         });
116                                 }
117
118                                 out.extend(quote! {
119                                         #[repr(#repr)]
120                                         #[derive(Clone, PartialEq)]
121                                 });
122
123                                 if !args.custom {
124                                         out.extend(quote! {
125                                                 #[cfg_attr(feature = #serializer, derive(MtSerialize))]
126                                                 #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
127                                         });
128                                 }
129                         }
130
131                         out.extend(quote! {
132                                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
133                         });
134                 }
135                 syn::Item::Struct(s) => {
136                         iter!(s.attrs, wrap_attr);
137                         iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
138
139                         out.extend(quote! {
140                                 #[derive(Clone, PartialEq)]
141                         });
142
143                         if !args.custom {
144                                 out.extend(quote! {
145                                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
146                                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
147                                 });
148                         }
149                 }
150                 _ => panic!("only enum and struct supported"),
151         }
152
153         out.extend(input.to_token_stream());
154         out.into()
155 }
156
157 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
158 #[darling(attributes(mt))]
159 #[darling(default)]
160 struct MtArgs {
161         #[darling(multiple)]
162         const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
163
164         #[darling(multiple)]
165         const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq
166
167         size: Option<syn::Type>, // must implement MtCfg
168
169         len: Option<syn::Type>, // must implement MtCfg
170
171         zlib: bool,
172         zstd: bool, // TODO
173         default: bool, // type must implement Default
174 }
175
176 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
177
178 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
179         match fields {
180                 syn::Fields::Named(fs) => fs
181                         .named
182                         .iter()
183                         .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
184                         .collect(),
185                 syn::Fields::Unnamed(fs) => fs
186                         .unnamed
187                         .iter()
188                         .enumerate()
189                         .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
190                         .collect(),
191                 syn::Fields::Unit => Vec::new(),
192         }
193 }
194
195 fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
196         match res {
197                 Ok(args) => {
198                         let mut code = body(&args);
199
200                         if args.zlib {
201                                 code = quote! {
202                                         let mut __writer = {
203                                                 let mut __stream = mt_ser::flate2::write::ZlibEncoder::new(
204                                                         __writer,
205                                                         mt_ser::flate2::Compression::default(),
206                                                 );
207                                                 let __writer = &mut __stream;
208                                                 #code
209                                                 __stream.finish()?
210                                         };
211                                 };
212                         }
213
214                         if let Some(size) = args.size {
215                                 code = quote! {
216                                         mt_ser::MtSerialize::mt_serialize::<#size>(&{
217                                                 let mut __buf = Vec::new();
218                                                 let __writer = &mut __buf;
219                                                 #code
220                                                 __buf
221                                         }, __writer)?;
222                                 };
223                         }
224
225                         for x in args.const_before.iter().rev() {
226                                 code = quote! {
227                                         #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
228                                         #code
229                                 }
230                         }
231
232                         for x in args.const_after.iter() {
233                                 code = quote! {
234                                         #code
235                                         #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
236                                 }
237                         }
238
239                         code
240                 }
241                 Err(e) => return e.write_errors(),
242         }
243 }
244
245 fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
246         match res {
247                 Ok(args) => {
248                         let mut code = body(&args);
249
250                         if args.zlib {
251                                 code = quote! {
252                                         {
253                                                 let mut __owned_reader = mt_ser::flate2::read::ZlibDecoder::new(
254                                                         mt_ser::WrapRead(__reader));
255                                                 let __reader = &mut __owned_reader;
256
257                                                 #code
258                                         }
259                                 }
260                         }
261
262
263                         if let Some(size) = args.size {
264                                 code = quote! {
265                                         #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
266                                                 let mut __owned_reader = std::io::Read::take(
267                                                         mt_ser::WrapRead(__reader), size as u64);
268                                                 let __reader = &mut __owned_reader;
269
270                                                 #code
271                                         })
272                                 };
273                         }
274
275                         let impl_const = |value: TokStr| quote! {
276                                 {
277                                         fn deserialize_same_type<T: MtDeserialize>(
278                                                 _: &T,
279                                                 reader: &mut impl std::io::Read
280                                         ) -> Result<T, DeserializeError> {
281                                                 T::mt_deserialize::<mt_ser::DefCfg>(reader)
282                                         }
283
284                                         deserialize_same_type(&want, __reader)
285                                                 .and_then(|got| {
286                                                         if want == got {
287                                                                 #value
288                                                         } else {
289                                                                 Err(mt_ser::DeserializeError::InvalidConst(
290                                                                         Box::new(want), Box::new(got)
291                                                                 ))
292                                                         }
293                                                 })
294                                 }
295                         };
296
297                         for want in args.const_before.iter().rev() {
298                                 let imp = impl_const(code);
299                                 code = quote! {
300                                         {
301                                                 let want = #want;
302                                                 #imp
303                                         }
304                                 };
305                         }
306
307                         for want in args.const_after.iter() {
308                                 let imp = impl_const(quote! { Ok(value) });
309                                 code = quote! {
310                                         {
311                                                 let want = #want;
312                                                 #code.and_then(|value| { #imp })
313                                         }
314                                 };
315                         }
316
317                         code
318                 }
319                 Err(e) => return e.write_errors(),
320         }
321 }
322
323 fn serialize_fields(fields: &Fields) -> TokStr {
324         fields
325                 .iter()
326                 .map(|(ident, field)| {
327                         serialize_args(MtArgs::from_field(field), |args| {
328                                 let def = parse_quote! { mt_ser::DefCfg };
329                                 let len = args.len.as_ref().unwrap_or(&def);
330                                 quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }
331                         })
332                 })
333                 .collect()
334 }
335
336 fn deserialize_fields(fields: &Fields) -> TokStr {
337         fields
338                 .iter()
339                 .map(|(ident, field)| {
340                         let code = deserialize_args(MtArgs::from_field(field), |args| {
341                                 let def = parse_quote! { mt_ser::DefCfg };
342                                 let len = args.len.as_ref().unwrap_or(&def);
343                                 let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
344
345                                 if args.default {
346                                         code = quote! {
347                                                 mt_ser::OrDefault::or_default(#code)
348                                         };
349                                 }
350
351                                 code
352                         });
353
354                         quote! {
355                                 let #ident = #code?;
356                         }
357                 })
358                 .collect()
359 }
360
361 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
362         let ident_fn = match input {
363                 syn::Fields::Unnamed(_) => |f| {
364                         quote! {
365                                 mt_ser::paste::paste! { [<field_ #f>] }
366                         }
367                 },
368                 _ => |f| quote! { #f },
369         };
370
371         let fields = get_fields(input, ident_fn);
372         let fields_comma: TokStr = fields
373                 .iter()
374                 .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
375
376         let fields_struct = match input {
377                 syn::Fields::Named(_) => quote! { { #fields_comma } },
378                 syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
379                 syn::Fields::Unit => TokStr::new(),
380         };
381
382         (fields, fields_struct)
383 }
384
385 fn get_repr(input: &syn::DeriveInput) -> syn::Type {
386         input
387                 .attrs
388                 .iter()
389                 .find(|a| a.path.is_ident("repr"))
390                 .expect("missing repr")
391                 .parse_args()
392                 .expect("invalid repr")
393 }
394
395 #[proc_macro_derive(MtSerialize, attributes(mt))]
396 pub fn derive_serialize(input: TokenStream) -> TokenStream {
397         let input = parse_macro_input!(input as syn::DeriveInput);
398         let typename = &input.ident;
399
400         let code = serialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
401                 syn::Data::Enum(e) => {
402                         let repr = get_repr(&input);
403                         let variants: TokStr = e.variants
404                                 .iter()
405                                 .fold((parse_quote! { 0 }, TokStr::new()), |(discr, before), v| {
406                                         let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
407                                         let (fields, fields_struct) = get_fields_struct(&v.fields);
408
409                                         let code = serialize_args(MtArgs::from_variant(v), |_|
410                                                 serialize_fields(&fields));
411                                         let variant = &v.ident;
412
413                                         (
414                                                 parse_quote! { 1 + #discr },
415                                                 quote! {
416                                                         #before
417                                                         #typename::#variant #fields_struct => {
418                                                                 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
419                                                                 #code
420                                                         }
421                                                 }
422                                         )
423                                 }).1;
424
425                         quote! {
426                                 match self {
427                                         #variants
428                                 }
429                         }
430                 }
431                 syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
432                 _ => {
433                         panic!("only enum and struct supported");
434                 }
435         });
436
437         quote! {
438                 #[automatically_derived]
439                 impl mt_ser::MtSerialize for #typename {
440                         fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
441                                 #code
442
443                                 Ok(())
444                         }
445                 }
446         }.into()
447 }
448
449 #[proc_macro_derive(MtDeserialize, attributes(mt))]
450 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
451         let input = parse_macro_input!(input as syn::DeriveInput);
452         let typename = &input.ident;
453
454         let code = deserialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
455                 syn::Data::Enum(e) => {
456                         let repr = get_repr(&input);
457                         let type_str = typename.to_string();
458
459                         let mut consts = TokStr::new();
460                         let mut arms = TokStr::new();
461                         let mut discr = parse_quote! { 0 };
462
463                         for v in e.variants.iter() {
464                                 discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
465
466                                 let ident = &v.ident;
467                                 let (fields, fields_struct) = get_fields_struct(&v.fields);
468                                 let code = deserialize_args(MtArgs::from_variant(v), |_| {
469                                         let fields_code = deserialize_fields(&fields);
470
471                                         quote! {
472                                                 #fields_code
473                                                 Ok(Self::#ident #fields_struct)
474                                         }
475                                 });
476
477                                 consts.extend(quote! {
478                                         const #ident: #repr = #discr;
479                                 });
480
481                                 arms.extend(quote! {
482                                         #ident => { #code }
483                                 });
484
485                                 discr = parse_quote! { 1 + #discr };
486                         }
487
488                         quote! {
489                                 #consts
490
491                                 match mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)? {
492                                         #arms
493                                         x => Err(mt_ser::DeserializeError::InvalidEnumVariant(#type_str, x as u64))
494                                 }
495                         }
496                 }
497                 syn::Data::Struct(s) => {
498                         let (fields, fields_struct) = get_fields_struct(&s.fields);
499                         let code = deserialize_fields(&fields);
500
501                         quote! {
502                                 #code
503                                 Ok(Self #fields_struct)
504                         }
505                 }
506                 _ => {
507                         panic!("only enum and struct supported");
508                 }
509         });
510
511         quote! {
512                 #[automatically_derived]
513                 impl mt_ser::MtDeserialize for #typename {
514                         #[allow(non_upper_case_globals)]
515                         fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
516                                 #code
517                         }
518                 }
519         }.into()
520 }