]> git.lizzy.rs Git - mt_ser.git/blob - derive/src/lib.rs
Initial commit
[mt_ser.git] / derive / src / lib.rs
1 use darling::FromMeta;
2 use proc_macro::{self, TokenStream};
3 use quote::{quote, ToTokens};
4 use syn::{parse_macro_input, parse_quote};
5
6 #[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
7 #[darling(rename_all = "snake_case")]
8 enum To {
9         Clt,
10         Srv,
11 }
12
13 #[derive(Debug, FromMeta)]
14 struct MacroArgs {
15         to: To,
16         repr: Option<syn::Type>,
17         tag: Option<String>,
18         content: Option<String>,
19         #[darling(default)]
20         enumset: bool,
21 }
22
23 fn wrap_attr(attr: &mut syn::Attribute) {
24         match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
25                 Some("mt") => {
26                         let path = attr.path.clone();
27                         let tokens = attr.tokens.clone();
28
29                         *attr = parse_quote! {
30                                 #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
31                         };
32                 }
33                 Some("serde") => {
34                         let path = attr.path.clone();
35                         let tokens = attr.tokens.clone();
36
37                         *attr = parse_quote! {
38                                 #[cfg_attr(feature = "serde", #path #tokens)]
39                         };
40                 }
41                 _ => {}
42         }
43 }
44
45 #[proc_macro_attribute]
46 pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
47         let item2 = item.clone();
48
49         let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
50         let mut input = parse_macro_input!(item2 as syn::Item);
51
52         let args = match MacroArgs::from_list(&attr_args) {
53                 Ok(v) => v,
54                 Err(e) => {
55                         return TokenStream::from(e.write_errors());
56                 }
57         };
58
59         let (serializer, deserializer) = match args.to {
60                 To::Clt => ("server", "client"),
61                 To::Srv => ("client", "server"),
62         };
63
64         let mut out = quote! {
65                 #[derive(Debug)]
66                 #[cfg_attr(feature = "random", derive(GenerateRandom))]
67                 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
68         };
69
70         macro_rules! iter {
71                 ($t:expr, $f:expr) => {
72                         $t.iter_mut().for_each($f)
73                 };
74         }
75
76         match &mut input {
77                 syn::Item::Enum(e) => {
78                         iter!(e.attrs, wrap_attr);
79                         iter!(e.variants, |v| {
80                                 iter!(v.attrs, wrap_attr);
81                                 iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
82                         });
83
84                         let repr = args.repr.expect("missing repr for enum");
85
86                         if args.enumset {
87                                 let repr_str = repr.to_token_stream().to_string();
88
89                                 out.extend(quote! {
90                                         #[derive(EnumSetType)]
91                                         #[enumset(repr = #repr_str, serialize_as_map)]
92                                 })
93                         } else {
94                                 let has_payload = e
95                                         .variants
96                                         .iter()
97                                         .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
98                                         .is_some();
99
100                                 if has_payload {
101                                         let tag = args.tag.expect("missing tag for enum with payload");
102
103                                         out.extend(quote! {
104                                                 #[derive(Clone)]
105                                                 #[cfg_attr(feature = "serde", serde(tag = #tag))]
106                                         });
107
108                                         if let Some(content) = args.content {
109                                                 out.extend(quote! {
110                                                         #[cfg_attr(feature = "serde", serde(content = #content))]
111                                                 });
112                                         }
113                                 } else {
114                                         out.extend(quote! {
115                                                 #[derive(Copy, Clone, PartialEq, Eq)]
116                                         });
117                                 }
118
119                                 out.extend(quote! {
120                                         #[repr(#repr)]
121                                         #[cfg_attr(feature = #serializer, derive(MtSerialize))]
122                                         #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
123                                 });
124                         }
125
126                         out.extend(quote! {
127                                 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
128                         });
129                 }
130                 syn::Item::Struct(s) => {
131                         iter!(s.attrs, wrap_attr);
132                         iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
133
134                         out.extend(quote! {
135                                 #[derive(Clone)]
136                                 #[cfg_attr(feature = #serializer, derive(MtSerialize))]
137                                 #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
138                         });
139                 }
140                 _ => panic!("only enum and struct supported"),
141         }
142
143         out.extend(input.to_token_stream());
144         out.into()
145 }
146
147 #[proc_macro_derive(MtSerialize, attributes(mt))]
148 pub fn derive_serialize(input: TokenStream) -> TokenStream {
149         let syn::DeriveInput { ident, .. } = parse_macro_input!(input);
150         let output = quote! {
151                 impl MtSerialize for #ident {
152                         fn mt_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<(), mt_data::SerializeError> {
153                                 Err(mt_data::SerializeError::Unimplemented)
154                         }
155                 }
156         };
157         output.into()
158 }
159
160 #[proc_macro_derive(MtDeserialize, attributes(mt))]
161 pub fn derive_deserialize(input: TokenStream) -> TokenStream {
162         let syn::DeriveInput { ident, .. } = parse_macro_input!(input);
163         quote! {
164                 impl MtDeserialize for #ident {
165                         fn mt_deserialize<R: std::io::Read>(reader: &mut R) -> Result<Self, mt_data::DeserializeError> {
166                                 Err(mt_data::DeserializeError::Unimplemented)
167                         }
168                 }
169         }.into()
170 }