git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
Support string repr
[mt_ser.git] / derive / src / lib.rs
1 use convert_case::{Case, Casing};
2 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
3 use proc_macro::TokenStream;
4 use proc_macro2::TokenStream as TokStr;
5 use quote::{quote, ToTokens};
6 use syn::{parse_macro_input, parse_quote};
7
/// Value of the `to` argument of `mt_derive`.
///
/// `mt_derive` uses it to pick the cargo features that gate the derived
/// `MtSerialize` / `MtDeserialize` implementations. Because of
/// `rename_all = "snake_case"` the values are spelled `clt` and `srv`.
#[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
#[darling(rename_all = "snake_case")]
enum To {
    /// Serializer is gated on feature "server", deserializer on feature "client".
    Clt,
    /// Serializer is gated on feature "client", deserializer on feature "server".
    Srv,
}
14
/// Arguments accepted by the `mt_derive` attribute macro.
#[derive(Debug, FromMeta)]
struct MacroArgs {
    /// Selects which feature gates the serializer and which the deserializer.
    to: To,
    /// Enum representation: `str` emits `#[mt(string_repr)]`, any other type
    /// is emitted as `#[repr(..)]` (or used as the enumset repr).
    repr: Option<syn::Type>,
    /// Forwarded as `serde(tag = ..)`; required for enums that have a variant with fields.
    tag: Option<String>,
    /// Forwarded as `serde(content = ..)` for enums that have a variant with fields.
    content: Option<String>,
    /// When set, `MtSerialize` / `MtDeserialize` are not derived.
    #[darling(default)]
    custom: bool,
    /// When set on an enum, `EnumSetType` is derived instead of the regular enum derives.
    #[darling(default)]
    enumset: bool,
}
26
27 fn wrap_attr(attr: &mut syn::Attribute) {
28     let path = attr.path.clone();
29     let tokens = attr.tokens.clone();
30
31     match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
32         Some("mt") => {
33             *attr = parse_quote! {
34                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
35             };
36         }
37         Some("serde") => {
38             *attr = parse_quote! {
39                 #[cfg_attr(feature = "serde", #path #tokens)]
40             };
41         }
42         _ => {}
43     }
44 }
45
46 #[proc_macro_attribute]
47 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
48     let item2 = item.clone();
49
50     let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
51     let mut input = parse_macro_input!(item2 as syn::Item);
52
53     let args = match MacroArgs::from_list(&attr_args) {
54         Ok(v) => v,
55         Err(e) => {
56             return TokenStream::from(e.write_errors());
57         }
58     };
59
60     let (serializer, deserializer) = match args.to {
61         To::Clt => ("server", "client"),
62         To::Srv => ("client", "server"),
63     };
64
65     let mut out = quote! {
66         #[derive(Debug)]
67         #[cfg_attr(feature = "random", derive(GenerateRandom))]
68         #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
69     };
70
71     macro_rules! iter {
72         ($t:expr, $f:expr) => {
73             $t.iter_mut().for_each($f)
74         };
75     }
76
77     match &mut input {
78         syn::Item::Enum(e) => {
79             iter!(e.attrs, wrap_attr);
80             iter!(e.variants, |v| {
81                 iter!(v.attrs, wrap_attr);
82                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
83             });
84
85             if args.enumset {
86                 let repr_str = args
87                     .repr
88                     .expect("missing repr for enum")
89                     .to_token_stream()
90                     .to_string();
91
92                 out.extend(quote! {
93                     #[derive(EnumSetType)]
94                     #[enumset(repr = #repr_str, serialize_as_map)]
95                 })
96             } else {
97                 let has_payload = e
98                     .variants
99                     .iter()
100                     .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
101                     .is_some();
102
103                 if has_payload {
104                     let tag = args.tag.expect("missing tag for enum with payload");
105
106                     out.extend(quote! {
107                         #[cfg_attr(feature = "serde", serde(tag = #tag))]
108                     });
109
110                     if let Some(content) = args.content {
111                         out.extend(quote! {
112                             #[cfg_attr(feature = "serde", serde(content = #content))]
113                         });
114                     }
115                 } else {
116                     out.extend(quote! {
117                         #[derive(Copy, Eq)]
118                     });
119                 }
120
121                 out.extend(quote! {
122                     #[derive(Clone, PartialEq)]
123                 });
124
125                 if !args.custom {
126                     out.extend(quote! {
127                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
128                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
129                     });
130                 }
131
132                 if let Some(repr) = args.repr {
133                     if repr == parse_quote! { str } {
134                         out.extend(quote! {
135                             #[mt(string_repr)]
136                         });
137                     } else {
138                         out.extend(quote! {
139                             #[repr(#repr)]
140                         });
141                     }
142                 } else if !args.custom {
143                     panic!("missing repr for enum");
144                 }
145             }
146
147             out.extend(quote! {
148                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
149             });
150         }
151         syn::Item::Struct(s) => {
152             iter!(s.attrs, wrap_attr);
153             iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
154
155             out.extend(quote! {
156                 #[derive(Clone, PartialEq)]
157             });
158
159             if !args.custom {
160                 out.extend(quote! {
161                     #[cfg_attr(feature = #serializer, derive(MtSerialize))]
162                     #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
163                 });
164             }
165         }
166         _ => panic!("only enum and struct supported"),
167     }
168
169     out.extend(input.to_token_stream());
170     out.into()
171 }
172
/// Options read from `#[mt(..)]` attributes on a type, an enum variant or a field.
///
/// Every option is optional (`#[darling(default)]`).
#[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
#[darling(attributes(mt))]
#[darling(default)]
struct MtArgs {
    /// Constants written before the payload and checked for equality when reading.
    #[darling(multiple)]
    const_before: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq

    /// Constants written after the payload and checked for equality when reading.
    #[darling(multiple)]
    const_after: Vec<syn::Expr>, // must implement MtSerialize + MtDeserialize + PartialEq

    /// The payload is serialized into a temporary buffer that is written using
    /// this type as config; when reading, a value of this type is read first
    /// and limits how many bytes the payload may consume.
    size: Option<syn::Type>, // must implement MtCfg

    /// Config type passed to the field's (de)serializer instead of `mt_ser::DefCfg`.
    len: Option<syn::Type>, // must implement MtCfg

    /// Wrap the payload in a zlib encoder / decoder.
    zlib: bool,
    /// Not used by the code generators in this file yet.
    zstd: bool,    // TODO
    /// Pass the deserialization result through `mt_ser::OrDefault::or_default`.
    default: bool, // type must implement Default

    /// Use the snake_case variant name as discriminant instead of a numeric repr.
    string_repr: bool, // for enums
}
193
/// List of fields, each paired with the tokens used to refer to it in generated code.
type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
195
196 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
197     match fields {
198         syn::Fields::Named(fs) => fs
199             .named
200             .iter()
201             .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
202             .collect(),
203         syn::Fields::Unnamed(fs) => fs
204             .unnamed
205             .iter()
206             .enumerate()
207             .map(|(i, f)| (ident(syn::Index::from(i).to_token_stream()), f))
208             .collect(),
209         syn::Fields::Unit => Vec::new(),
210     }
211 }
212
213 fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
214     match res {
215         Ok(args) => {
216             let mut code = body(&args);
217
218             if args.zlib {
219                 code = quote! {
220                     let mut __writer = {
221                         let mut __stream = mt_ser::flate2::write::ZlibEncoder::new(
222                             __writer,
223                             mt_ser::flate2::Compression::default(),
224                         );
225                         let __writer = &mut __stream;
226                         #code
227                         __stream.finish()?
228                     };
229                 };
230             }
231
232             if let Some(size) = args.size {
233                 code = quote! {
234                     mt_ser::MtSerialize::mt_serialize::<#size>(&{
235                         let mut __buf = Vec::new();
236                         let __writer = &mut __buf;
237                         #code
238                         __buf
239                     }, __writer)?;
240                 };
241             }
242
243             for x in args.const_before.iter().rev() {
244                 code = quote! {
245                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
246                     #code
247                 }
248             }
249
250             for x in args.const_after.iter() {
251                 code = quote! {
252                     #code
253                     #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
254                 }
255             }
256
257             code
258         }
259         Err(e) => return e.write_errors(),
260     }
261 }
262
263 fn deserialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
264     match res {
265         Ok(args) => {
266             let mut code = body(&args);
267
268             if args.zlib {
269                 code = quote! {
270                     {
271                         let mut __owned_reader = mt_ser::flate2::read::ZlibDecoder::new(
272                             mt_ser::WrapRead(__reader));
273                         let __reader = &mut __owned_reader;
274
275                         #code
276                     }
277                 }
278             }
279
280             if let Some(size) = args.size {
281                 code = quote! {
282                     #size::mt_deserialize::<DefCfg>(__reader).and_then(|size| {
283                         let mut __owned_reader = std::io::Read::take(
284                             mt_ser::WrapRead(__reader), size as u64);
285                         let __reader = &mut __owned_reader;
286
287                         #code
288                     })
289                 };
290             }
291
292             let impl_const = |value: TokStr| {
293                 quote! {
294                     {
295                         fn deserialize_same_type<T: MtDeserialize>(
296                             _: &T,
297                             reader: &mut impl std::io::Read
298                         ) -> Result<T, mt_ser::DeserializeError> {
299                             T::mt_deserialize::<mt_ser::DefCfg>(reader)
300                         }
301
302                         deserialize_same_type(&want, __reader)
303                             .and_then(|got| {
304                                 if want == got {
305                                     #value
306                                 } else {
307                                     Err(mt_ser::DeserializeError::InvalidConst(
308                                         Box::new(want), Box::new(got)
309                                     ))
310                                 }
311                             })
312                     }
313                 }
314             };
315
316             for want in args.const_before.iter().rev() {
317                 let imp = impl_const(code);
318                 code = quote! {
319                     {
320                         let want = #want;
321                         #imp
322                     }
323                 };
324             }
325
326             for want in args.const_after.iter() {
327                 let imp = impl_const(quote! { Ok(value) });
328                 code = quote! {
329                     {
330                         let want = #want;
331                         #code.and_then(|value| { #imp })
332                     }
333                 };
334             }
335
336             code
337         }
338         Err(e) => return e.write_errors(),
339     }
340 }
341
342 fn serialize_fields(fields: &Fields) -> TokStr {
343     fields
344         .iter()
345         .map(|(ident, field)| {
346             serialize_args(MtArgs::from_field(field), |args| {
347                 let def = parse_quote! { mt_ser::DefCfg };
348                 let len = args.len.as_ref().unwrap_or(&def);
349                 quote! { mt_ser::MtSerialize::mt_serialize::<#len>(#ident, __writer)?; }
350             })
351         })
352         .collect()
353 }
354
355 fn deserialize_fields(fields: &Fields) -> TokStr {
356     fields
357         .iter()
358         .map(|(ident, field)| {
359             let code = deserialize_args(MtArgs::from_field(field), |args| {
360                 let def = parse_quote! { mt_ser::DefCfg };
361                 let len = args.len.as_ref().unwrap_or(&def);
362                 let mut code = quote! { mt_ser::MtDeserialize::mt_deserialize::<#len>(__reader) };
363
364                 if args.default {
365                     code = quote! {
366                         mt_ser::OrDefault::or_default(#code)
367                     };
368                 }
369
370                 code
371             });
372
373             quote! {
374                 let #ident = #code?;
375             }
376         })
377         .collect()
378 }
379
380 fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) {
381     let ident_fn = match input {
382         syn::Fields::Unnamed(_) => |f| {
383             quote! {
384                 mt_ser::paste::paste! { [<field_ #f>] }
385             }
386         },
387         _ => |f| quote! { #f },
388     };
389
390     let fields = get_fields(input, ident_fn);
391     let fields_comma: TokStr = fields
392         .iter()
393         .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
394
395     let fields_struct = match input {
396         syn::Fields::Named(_) => quote! { { #fields_comma } },
397         syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
398         syn::Fields::Unit => TokStr::new(),
399     };
400
401     (fields, fields_struct)
402 }
403
404 fn get_repr(input: &syn::DeriveInput, args: &MtArgs) -> syn::Type {
405     if args.string_repr {
406         parse_quote! { &str }
407     } else {
408         input
409             .attrs
410             .iter()
411             .find(|a| a.path.is_ident("repr"))
412             .expect("missing repr")
413             .parse_args()
414             .expect("invalid repr")
415     }
416 }
417
418 fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Variant, &syn::Expr)) {
419     let mut discr = parse_quote! { 0 };
420
421     for v in e.variants.iter() {
422         discr = if args.string_repr {
423             let lit = v.ident.to_string().to_case(Case::Snake);
424             parse_quote! { #lit }
425         } else {
426             v.discriminant.clone().map(|x| x.1).unwrap_or(discr)
427         };
428
429         f(&v, &discr);
430
431         discr = parse_quote! { 1 + #discr };
432     }
433 }
434
435 #[proc_macro_derive(MtSerialize, attributes(mt))]
436 pub fn derive_serialize(input: TokenStream) -> TokenStream {
437     let input = parse_macro_input!(input as syn::DeriveInput);
438     let typename = &input.ident;
439
440     let code = serialize_args(MtArgs::from_derive_input(&input), |args| {
441         match &input.data {
442             syn::Data::Enum(e) => {
443                 let repr = get_repr(&input, &args);
444                 let mut variants = TokStr::new();
445
446                 iter_variants(&e, &args, |v, discr| {
447                     let (fields, fields_struct) = get_fields_struct(&v.fields);
448                     let code =
449                         serialize_args(MtArgs::from_variant(v), |_| serialize_fields(&fields));
450                     let ident = &v.ident;
451
452                     variants.extend(quote! {
453                                         #typename::#ident #fields_struct => {
454                                                 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
455                                                 #code
456                                         }
457                                 });
458                 });
459
460                 quote! {
461                     match self {
462                         #variants
463                     }
464                 }
465             }
466             syn::Data::Struct(s) => {
467                 serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f }))
468             }
469             _ => {
470                 panic!("only enum and struct supported");
471             }
472         }
473     });
474
475     quote! {
476                 #[automatically_derived]
477                 impl mt_ser::MtSerialize for #typename {
478                         fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
479                                 #code
480
481                                 Ok(())
482                         }
483                 }
484         }.into()
485 }
486
487 #[proc_macro_derive(MtDeserialize, attributes(mt))]
488 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
489     let input = parse_macro_input!(input as syn::DeriveInput);
490     let typename = &input.ident;
491
492     let code = deserialize_args(MtArgs::from_derive_input(&input), |args| {
493         match &input.data {
494             syn::Data::Enum(e) => {
495                 let repr = get_repr(&input, &args);
496
497                 let mut consts = TokStr::new();
498                 let mut arms = TokStr::new();
499
500                 iter_variants(&e, &args, |v, discr| {
501                     let ident = &v.ident;
502                     let (fields, fields_struct) = get_fields_struct(&v.fields);
503                     let code = deserialize_args(MtArgs::from_variant(v), |_| {
504                         let fields_code = deserialize_fields(&fields);
505
506                         quote! {
507                             #fields_code
508                             Ok(Self::#ident #fields_struct)
509                         }
510                     });
511
512                     consts.extend(quote! {
513                         const #ident: #repr = #discr;
514                     });
515
516                     arms.extend(quote! {
517                         #ident => { #code }
518                     });
519                 });
520
521                 let type_str = typename.to_string();
522                 let discr_match = if args.string_repr {
523                     quote! {
524                         let __discr = String::mt_deserialize::<DefCfg>(__reader)?;
525                         match __discr.as_str()
526                     }
527                 } else {
528                     quote! {
529                         let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?;
530                         match __discr
531                     }
532                 };
533
534                 quote! {
535                     #consts
536
537                     #discr_match {
538                         #arms
539                         _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr)))
540                     }
541                 }
542             }
543             syn::Data::Struct(s) => {
544                 let (fields, fields_struct) = get_fields_struct(&s.fields);
545                 let code = deserialize_fields(&fields);
546
547                 quote! {
548                     #code
549                     Ok(Self #fields_struct)
550                 }
551             }
552             _ => {
553                 panic!("only enum and struct supported");
554             }
555         }
556     });
557
558     quote! {
559                 #[automatically_derived]
560                 impl mt_ser::MtDeserialize for #typename {
561                         #[allow(non_upper_case_globals)]
562                         fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
563                                 #code
564                         }
565                 }
566         }.into()
567 }