]> git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
dd82c23cf3e60e47a8ad066b4f71a212ff7831c3
[mt_ser.git] / derive / src / lib.rs
1 use darling::{FromDeriveInput, FromField, FromMeta, FromVariant};
2 use proc_macro::TokenStream;
3 use proc_macro2::TokenStream as TokStr;
4 use quote::{quote, ToTokens};
5 use syn::{parse_macro_input, parse_quote};
6
7 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
8 #[darling(rename_all = "snake_case")]
9 enum To {
10         Clt,
11         Srv,
12 }
13
14 #[derive(Debug, FromMeta)]
15 struct MacroArgs {
16         to: To,
17         repr: Option<syn::Type>,
18         tag: Option<String>,
19         content: Option<String>,
20         #[darling(default)]
21         custom: bool,
22         #[darling(default)]
23         enumset: bool,
24 }
25
26 fn wrap_attr(attr: &mut syn::Attribute) {
27         let path = attr.path.clone();
28         let tokens = attr.tokens.clone();
29
30         match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
31                 Some("mt") => {
32                         *attr = parse_quote! {
33                                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
34                         };
35                 }
36                 Some("serde") => {
37                         *attr = parse_quote! {
38                                 #[cfg_attr(feature = "serde", #path #tokens)]
39                         };
40                 }
41                 _ => {}
42         }
43 }
44
45 #[proc_macro_attribute]
46 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
47         let item2 = item.clone();
48
49         let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
50         let mut input = parse_macro_input!(item2 as syn::Item);
51
52         let args = match MacroArgs::from_list(&attr_args) {
53                 Ok(v) => v,
54                 Err(e) => {
55                         return TokenStream::from(e.write_errors());
56                 }
57         };
58
59         let (serializer, deserializer) = match args.to {
60                 To::Clt => ("server", "client"),
61                 To::Srv => ("client", "server"),
62         };
63
64         let mut out = quote! {
65                 #[derive(Debug)]
66                 #[cfg_attr(feature = "random", derive(GenerateRandom))]
67                 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
68         };
69
70         macro_rules! iter {
71                 ($t:expr, $f:expr) => {
72                         $t.iter_mut().for_each($f)
73                 };
74         }
75
76         match &mut input {
77                 syn::Item::Enum(e) => {
78                         iter!(e.attrs, wrap_attr);
79                         iter!(e.variants, |v| {
80                                 iter!(v.attrs, wrap_attr);
81                                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
82                         });
83
84                         let repr = args.repr.expect("missing repr for enum");
85
86                         if args.enumset {
87                                 let repr_str = repr.to_token_stream().to_string();
88
89                                 out.extend(quote! {
90                                         #[derive(EnumSetType)]
91                                         #[enumset(repr = #repr_str, serialize_as_map)]
92                                 })
93                         } else {
94                                 let has_payload = e
95                                         .variants
96                                         .iter()
97                                         .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
98                                         .is_some();
99
100                                 if has_payload {
101                                         let tag = args.tag.expect("missing tag for enum with payload");
102
103                                         out.extend(quote! {
104                                                 #[cfg_attr(feature = "serde", serde(tag = #tag))]
105                                         });
106
107                                         if let Some(content) = args.content {
108                                                 out.extend(quote! {
109                                                         #[cfg_attr(feature = "serde", serde(content = #content))]
110                                                 });
111                                         }
112                                 } else {
113                                         out.extend(quote! {
114                                                 #[derive(Copy, Eq)]
115                                         });
116                                 }
117
118                                 out.extend(quote! {
119                                         #[repr(#repr)]
120                                         #[derive(Clone, PartialEq)]
121                                 });
122
123                                 if !args.custom {
124                                         out.extend(quote! {
125                                                 #[cfg_attr(feature = #serializer, derive(MtSerialize))]
126                                                 #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
127                                         });
128                                 }
129                         }
130
131                         out.extend(quote! {
132                                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
133                         });
134                 }
135                 syn::Item::Struct(s) => {
136                         iter!(s.attrs, wrap_attr);
137                         iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
138
139                         out.extend(quote! {
140                                 #[derive(Clone, PartialEq)]
141                         });
142
143                         if !args.custom {
144                                 out.extend(quote! {
145                                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
146                                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
147                                 });
148                         }
149                 }
150                 _ => panic!("only enum and struct supported"),
151         }
152
153         out.extend(input.to_token_stream());
154         out.into()
155 }
156
157 #[derive(Debug, Default, FromDeriveInput, FromVariant, FromField)]
158 #[darling(attributes(mt))]
159 #[darling(default)]
160 struct MtArgs {
161         const8: Option<u8>,
162         const16: Option<u16>,
163         const32: Option<u32>,
164         const64: Option<u64>,
165         size8: bool,
166         size16: bool,
167         size32: bool,
168         size64: bool,
169         len0: bool,
170         len8: bool,
171         len16: bool,
172         len32: bool,
173         len64: bool,
174         utf16: bool,
175         zlib: bool,
176         zstd: bool, // TODO
177         default: bool,
178 }
179
180 fn get_cfg(args: &MtArgs) -> syn::Type {
181         let mut ty: syn::Type = parse_quote! { mt_ser::DefCfg  };
182
183         if args.len0 {
184                 ty = parse_quote! { () };
185         }
186
187         macro_rules! impl_len {
188                 ($name:ident, $T:ty) => {
189                         if args.$name {
190                                 ty = parse_quote! { $T  };
191                         }
192                 };
193         }
194
195         impl_len!(len8, u8);
196         impl_len!(len16, u16);
197         impl_len!(len32, u32);
198         impl_len!(len64, u64);
199
200         if args.utf16 {
201                 ty = parse_quote! { mt_ser::Utf16<#ty> };
202         }
203
204         ty
205 }
206
207 type Fields<'a> = Vec<(TokStr, &'a syn::Field)>;
208
209 fn get_fields(fields: &syn::Fields, ident: impl Fn(TokStr) -> TokStr) -> Fields {
210         match fields {
211                 syn::Fields::Named(fs) => fs
212                         .named
213                         .iter()
214                         .map(|f| (ident(f.ident.as_ref().unwrap().to_token_stream()), f))
215                         .collect(),
216                 syn::Fields::Unnamed(fs) => fs
217                         .unnamed
218                         .iter()
219                         .enumerate()
220                         .map(|(i, f)| (ident(i.to_string().to_token_stream()), f))
221                         .collect(),
222                 syn::Fields::Unit => Vec::new(),
223         }
224 }
225
226 fn serialize_args(res: darling::Result<MtArgs>, body: impl FnOnce(&MtArgs) -> TokStr) -> TokStr {
227         match res {
228                 Ok(args) => {
229                         let mut code = TokStr::new();
230
231                         macro_rules! impl_const {
232                                 ($name:ident) => {
233                                         if let Some(x) = args.$name {
234                                                 code.extend(quote! {
235                                                         #x.mt_serialize::<mt_ser::DefCfg>(__writer)?;
236                                                 });
237                                         }
238                                 };
239                         }
240
241                         impl_const!(const8);
242                         impl_const!(const16);
243                         impl_const!(const32);
244                         impl_const!(const64);
245
246                         code.extend(body(&args));
247
248                         if args.zlib {
249                                 code = quote! {
250                                         let mut __writer = {
251                                                 let mut __stream = mt_ser::flate2::write::ZlibEncoder::new(
252                                                         __writer,
253                                                         mt_ser::flate2::Compression::default(),
254                                                 );
255                                                 let __writer = &mut __stream;
256                                                 #code
257                                                 __stream.finish()?
258                                         };
259                                 };
260                         }
261
262                         macro_rules! impl_size {
263                                 ($name:ident, $T:ty) => {
264                                         if args.$name {
265                                                 code = quote! {
266                                                                 mt_ser::MtSerialize::mt_serialize::<$T>(&{
267                                                                         let mut __buf = Vec::new();
268                                                                         let __writer = &mut __buf;
269                                                                         #code
270                                                                         __buf
271                                                                 }, __writer)?;
272                                                 };
273                                         }
274                                 };
275                         }
276
277                         impl_size!(size8, u8);
278                         impl_size!(size16, u16);
279                         impl_size!(size32, u32);
280                         impl_size!(size64, u64);
281
282                         code
283                 }
284                 Err(e) => return e.write_errors(),
285         }
286 }
287
288 fn serialize_fields(fields: &Fields) -> TokStr {
289         fields
290                 .iter()
291                 .map(|(ident, field)| {
292                         serialize_args(MtArgs::from_field(field), |args| {
293                                 let cfg = get_cfg(args);
294                                 quote! { mt_ser::MtSerialize::mt_serialize::<#cfg>(#ident, __writer)?; }
295                         })
296                 })
297                 .collect()
298 }
299
300 #[proc_macro_derive(MtSerialize, attributes(mt))]
301 pub fn derive_serialize(input: TokenStream) -> TokenStream {
302         let input = parse_macro_input!(input as syn::DeriveInput);
303         let typename = &input.ident;
304
305         let code = serialize_args(MtArgs::from_derive_input(&input), |_| match &input.data {
306                 syn::Data::Enum(e) => {
307                         let repr: syn::Type = input
308                                 .attrs
309                                 .iter()
310                                 .find(|a| a.path.is_ident("repr"))
311                                 .expect("missing repr")
312                                 .parse_args()
313                                 .expect("invalid repr");
314
315                         let variants: TokStr = e.variants
316                                 .iter()
317                                 .fold((parse_quote! { 0 }, TokStr::new()), |(discr, before), v| {
318                                         let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr);
319
320                                         let ident_fn = match &v.fields {
321                                                 syn::Fields::Unnamed(_) => |f| quote! {
322                                                         mt_ser::paste::paste! { [<field_ #f>] }
323                                                 },
324                                                 _ => |f| quote! { #f },
325                                         };
326
327                                         let fields = get_fields(&v.fields, ident_fn);
328                                         let fields_comma: TokStr = fields.iter()
329                                                 .rfold(TokStr::new(), |after, (ident, _)| quote! { #ident, #after });
330
331                                         let destruct = match &v.fields {
332                                                 syn::Fields::Named(_) => quote! { { #fields_comma } },
333                                                 syn::Fields::Unnamed(_) => quote! { ( #fields_comma ) },
334                                                 syn::Fields::Unit => TokStr::new(),
335                                         };
336
337                                         let code = serialize_args(MtArgs::from_variant(v), |_|
338                                                 serialize_fields(&fields));
339                                         let variant = &v.ident;
340
341                                         (
342                                                 parse_quote! { 1 + #discr },
343                                                 quote! {
344                                                         #before
345                                                         #typename::#variant #destruct => {
346                                                                 mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?;
347                                                                 #code
348                                                         }
349                                                 }
350                                         )
351                                 }).1;
352
353                         quote! {
354                                 match self {
355                                         #variants
356                                 }
357                         }
358                 }
359                 syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })),
360                 _ => {
361                         panic!("only enum and struct supported");
362                 }
363         });
364
365         quote! {
366                 #[automatically_derived]
367                 impl mt_ser::MtSerialize for #typename {
368                         fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> {
369                                 #code
370
371                                 Ok(())
372                         }
373                 }
374         }.into()
375 }
376
377 #[proc_macro_derive(MtDeserialize, attributes(mt))]
378 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
379         let syn::DeriveInput {
380                 ident: typename, ..
381         } = parse_macro_input!(input);
382         quote! {
383                 #[automatically_derived]
384                 impl mt_ser::MtDeserialize for #typename {
385                         fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> {
386                                 Err(mt_ser::DeserializeError::Unimplemented)
387                         }
388                 }
389         }.into()
390 }