git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
Implement deserialize for basic types
[mt_ser.git] / derive / src / lib.rs
1 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
2 use proc_macro::TokenStream;
3 use proc_macro2::TokenStream as TokStr;
4 use quote::{quote, ToTokens};
5 use syn::{parse_macro_input, parse_quote};
6
7 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
8 #[darling(rename_all = "snake_case")]
9 enum To {
10         Clt,
11         Srv,
12 }
13
14 #[derive(Debug, FromMeta)]
15 struct MacroArgs {
16         to: To,
17         repr: Option<syn::Type>,
18         tag: Option<String>,
19         content: Option<String>,
20         #[darling(default)]
21         custom: bool,
22         #[darling(default)]
23         enumset: bool,
24 }
25
26 fn wrap_attr(attr: &mut syn::Attribute) {
27         match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
28                 Some("mt") => {
29                         let path = attr.path.clone();
30                         let tokens = attr.tokens.clone();
31
32                         *attr = parse_quote! {
33                                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
34                         };
35                 }
36                 Some("serde") => {
37                         let path = attr.path.clone();
38                         let tokens = attr.tokens.clone();
39
40                         *attr = parse_quote! {
41                                 #[cfg_attr(feature = "serde", #path #tokens)]
42                         };
43                 }
44                 _ => {}
45         }
46 }
47
48 #[proc_macro_attribute]
49 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
50         let item2 = item.clone();
51
52         let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
53         let mut input = parse_macro_input!(item2 as syn::Item);
54
55         let args = match MacroArgs::from_list(&attr_args) {
56                 Ok(v) => v,
57                 Err(e) => {
58                         return TokenStream::from(e.write_errors());
59                 }
60         };
61
62         let (serializer, deserializer) = match args.to {
63                 To::Clt => ("server", "client"),
64                 To::Srv => ("client", "server"),
65         };
66
67         let mut out = quote! {
68                 #[derive(Debug)]
69                 #[cfg_attr(feature = "random", derive(GenerateRandom))]
70                 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
71         };
72
73         macro_rules! iter {
74                 ($t:expr, $f:expr) => {
75                         $t.iter_mut().for_each($f)
76                 };
77         }
78
79         match &mut input {
80                 syn::Item::Enum(e) => {
81                         iter!(e.attrs, wrap_attr);
82                         iter!(e.variants, |v| {
83                                 iter!(v.attrs, wrap_attr);
84                                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
85                         });
86
87                         let repr = args.repr.expect("missing repr for enum");
88
89                         if args.enumset {
90                                 let repr_str = repr.to_token_stream().to_string();
91
92                                 out.extend(quote! {
93                                         #[derive(EnumSetType)]
94                                         #[enumset(repr = #repr_str, serialize_as_map)]
95                                 })
96                         } else {
97                                 let has_payload = e
98                                         .variants
99                                         .iter()
100                                         .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
101                                         .is_some();
102
103                                 if has_payload {
104                                         let tag = args.tag.expect("missing tag for enum with payload");
105
106                                         out.extend(quote! {
107                                                 #[cfg_attr(feature = "serde", serde(tag = #tag))]
108                                         });
109
110                                         if let Some(content) = args.content {
111                                                 out.extend(quote! {
112                                                         #[cfg_attr(feature = "serde", serde(content = #content))]
113                                                 });
114                                         }
115                                 } else {
116                                         out.extend(quote! {
117                                                 #[derive(Copy, Eq)]
118                                         });
119                                 }
120
121                                 out.extend(quote! {
122                                         #[repr(#repr)]
123                                         #[derive(Clone, PartialEq)]
124                                 });
125
126                                 if !args.custom {
127                                         out.extend(quote! {
128                                                 #[cfg_attr(feature = #serializer, derive(MtSerialize))]
129                                                 #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
130                                         });
131                                 }
132                         }
133
134                         out.extend(quote! {
135                                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
136                         });
137                 }
138                 syn::Item::Struct(s) => {
139                         iter!(s.attrs, wrap_attr);
140                         iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
141
142                         out.extend(quote! {
143                                 #[derive(Clone, PartialEq)]
144                         });
145
146                         if !args.custom {
147                                 out.extend(quote! {
148                                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
149                                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
150                                 });
151                         }
152                 }
153                 _ => panic!("only enum and struct supported"),
154         }
155
156         out.extend(input.to_token_stream());
157         out.into()
158 }
159
160 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
161 #[darling(attributes(mt))]
162 #[darling(default)]
163 struct MtArgs {
164         const8: Option<u8>,
165         const16: Option<u16>,
166         const32: Option<u32>,
167         const64: Option<u64>,
168         size8: bool,
169         size16: bool,
170         size32: bool,
171         size64: bool,
172         len0: bool,
173         len8: bool,
174         len16: bool,
175         len32: bool,
176         len64: bool,
177         utf16: bool,
178         zlib: bool,
179         zstd: bool,
180         default: bool,
181 }
182
183 fn get_cfg(args: &MtArgs) -> syn::Type {
184         let mut ty: syn::Type = parse_quote! { mt_data::DefCfg  };
185
186         if args.len0 {
187                 ty = parse_quote! { mt_data::NoLen };
188         }
189
190         macro_rules! impl_len {
191                 ($name:ident, $T:ty) => {
192                         if args.$name {
193                                 ty = parse_quote! { $T  };
194                         }
195                 };
196         }
197
198         impl_len!(len8, u8);
199         impl_len!(len16, u16);
200         impl_len!(len32, u32);
201         impl_len!(len64, u64);
202
203         if args.utf16 {
204                 ty = parse_quote! { mt_data::Utf16<#ty> };
205         }
206
207         ty
208 }
209
210 /*
211 fn is_ident(path: &syn::Path, ident: &str) -> bool {
212         matches!(path.segments.first().map(|p| &p.ident), Some(idt) if idt == ident)
213 }
214
215 fn get_type_generics<const N: usize>(path: &syn::Path) -> Option<[&syn::Type; N]> {
216         use syn::{AngleBracketedGenericArguments as Args, PathArguments::AngleBracketed};
217
218         path.segments
219                 .first()
220                 .map(|seg| match &seg.arguments {
221                         AngleBracketed(Args { args, .. }) => args
222                                 .iter()
223                                 .flat_map(|arg| match arg {
224                                         syn::GenericArgument::Type(t) => Some(t),
225                                         _ => None,
226                                 })
227                                 .collect::<Vec<_>>()
228                                 .try_into()
229                                 .ok(),
230                         _ => None,
231                 })
232                 .flatten()
233 }
234 */
235
236 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
237
238 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
239         match fields {
240                 syn::Fields::Named(fs) => fs
241                         .named
242                         .iter()
243                         .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
244                         .collect(),
245                 syn::Fields::Unnamed(fs) => fs
246                         .unnamed
247                         .iter()
248                         .enumerate()
249                         .map(|(i, f)| (ident(i.to_string().to_token_stream()), f))
250                         .collect(),
251                 syn::Fields::Unit => Vec::new(),
252         }
253 }
254
255 fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
256         match res {
257                 Ok(args) => {
258                         let mut code = TokStr::new();
259
260                         macro_rules! impl_const {
261                                 ($name:ident) => {
262                                         if let Some(x) = args.$name {
263                                                 code.extend(quote! {
264                                                         #x.mt_serialize::<mt_data::DefCfg>(__writer)?;
265                                                 });
266                                         }
267                                 };
268                         }
269
270                         impl_const!(const8);
271                         impl_const!(const16);
272                         impl_const!(const32);
273                         impl_const!(const64);
274
275                         code.extend(body(&args));
276
277                         if args.zlib {
278                                 code = quote! {
279                                         let mut __writer = {
280                                                 let mut __stream = mt_data::flate2::write::ZlibEncoder::new(__writer, flate2::Compression::default());
281                                                 let __writer = &mut __stream;
282                                                 #code
283                                                 __stream.finish()?
284                                         };
285                                 };
286                         }
287
288                         macro_rules! impl_size {
289                                 ($name:ident, $T:ty) => {
290                                         if args.$name {
291                                                 code = quote! {
292                                                                 mt_data::MtSerialize::mt_serialize::<$T>(&{
293                                                                         let mut __buf = Vec::new();
294                                                                         let __writer = &mut __buf;
295                                                                         #code
296                                                                         __buf
297                                                                 }, __writer)?;
298                                                 };
299                                         }
300                                 };
301                         }
302
303                         impl_size!(size8, u8);
304                         impl_size!(size16, u16);
305                         impl_size!(size32, u32);
306                         impl_size!(size64, u64);
307
308                         code
309                 }
310                 Err(e) => return e.write_errors(),
311         }
312 }
313
314 fn serialize_fields(fields: &Fields) -> TokStr {
315         fields
316                 .iter()
317                 .map(|(ident, field)| {
318                         serialize_args(MtArgs::from_field(field), |args| {
319                                 let cfg = get_cfg(args);
320                                 quote! { mt_data::MtSerialize::mt_serialize::<#cfg>(#ident, __writer)?; }
321                         })
322                 })
323                 .collect()
324 }
325
326 #[proc_macro_derive(MtSerialize, attributes(mt))]
327 pub fn derive_serialize(input: TokenStream) -> TokenStream {
328         let input = parse_macro_input!(input as syn::DeriveInput);
329         let typename = &input.ident;
330
331         let code = serialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
332                 syn::Data::Enum(e) => {
333                         let repr: syn::Type = input
334                                 .attrs
335                                 .iter()
336                                 .find(|a| a.path.is_ident("repr"))
337                                 .expect("missing repr")
338                                 .parse_args()
339                                 .expect("invalid repr");
340
341                         let variants: TokStr = e.variants
342                                 .iter()
343                                 .fold((parse_quote! { 0 }, TokStr::new()), |(discr, before), v| {
344                                         let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
345
346                                         let ident_fn = match &v.fields {
347                                                 syn::Fields::Unnamed(_) => |f| quote! {
348                                                         mt_data::paste::paste! { [<field_ #f>] }
349                                                 },
350                                                 _ => |f| quote! { #f },
351                                         };
352
353                                         let fields = get_fields(&v.fields, ident_fn);
354                                         let fields_comma: TokStr = fields.iter()
355                                                 .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
356
357                                         let destruct = match &v.fields {
358                                                 syn::Fields::Named(_) => quote! { { #fields_comma } },
359                                                 syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
360                                                 syn::Fields::Unit => TokStr::new(),
361                                         };
362
363                                         let code = serialize_args(MtArgs::from_variant(v), |_|
364                                                 serialize_fields(&fields));
365                                         let variant = &v.ident;
366
367                                         (
368                                                 parse_quote! { 1 + #discr },
369                                                 quote! {
370                                                         #before
371                                                         #typename::#variant #destruct => {
372                                                                 mt_data::MtSerialize::mt_serialize::<mt_data::DefCfg>(&((#discr) as #repr), __writer)?;
373                                                                 #code
374                                                         }
375                                                 }
376                                         )
377                                 }).1;
378
379                         quote! {
380                                 match self {
381                                         #variants
382                                 }
383                         }
384                 }
385                 syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
386                 _ => {
387                         panic!("only enum and struct supported");
388                 }
389         });
390
391         quote! {
392                 #[automatically_derived]
393                 impl mt_data::MtSerialize for #typename {
394                         fn mt_serialize<C: MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_data::SerializeError> {
395                                 #code
396
397                                 Ok(())
398                         }
399                 }
400         }.into()
401 }
402
403 #[proc_macro_derive(MtDeserialize, attributes(mt))]
404 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
405         let syn::DeriveInput {
406                 ident: typename, ..
407         } = parse_macro_input!(input);
408         quote! {
409                 #[automatically_derived]
410                 impl mt_data::MtDeserialize for #typename {
411                         fn mt_deserialize<C: MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_data::DeserializeError> {
412                                 Err(mt_data::DeserializeError::Unimplemented)
413                         }
414                 }
415         }.into()
416 }