]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_macros/src/serialize.rs
783b47a49e5d7dd9cf2a6af9dfc292945f6f6a4a
[rust.git] / compiler / rustc_macros / src / serialize.rs
1 use proc_macro2::TokenStream;
2 use quote::{quote, quote_spanned};
3 use syn::parse_quote;
4 use syn::spanned::Spanned;
5
6 pub fn type_decodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
7     let decoder_ty = quote! { __D };
8     if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
9         s.add_impl_generic(parse_quote! { 'tcx });
10     }
11     s.add_impl_generic(parse_quote! {#decoder_ty: ::rustc_middle::ty::codec::TyDecoder<'tcx>});
12     s.add_bounds(synstructure::AddBounds::Generics);
13
14     decodable_body(s, decoder_ty)
15 }
16
17 pub fn meta_decodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
18     if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
19         s.add_impl_generic(parse_quote! { 'tcx });
20     }
21     s.add_impl_generic(parse_quote! { '__a });
22     let decoder_ty = quote! { DecodeContext<'__a, 'tcx> };
23     s.add_bounds(synstructure::AddBounds::Generics);
24
25     decodable_body(s, decoder_ty)
26 }
27
28 pub fn decodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
29     let decoder_ty = quote! { __D };
30     s.add_impl_generic(parse_quote! {#decoder_ty: ::rustc_serialize::Decoder});
31     s.add_bounds(synstructure::AddBounds::Generics);
32
33     decodable_body(s, decoder_ty)
34 }
35
36 fn decodable_body(
37     s: synstructure::Structure<'_>,
38     decoder_ty: TokenStream,
39 ) -> proc_macro2::TokenStream {
40     if let syn::Data::Union(_) = s.ast().data {
41         panic!("cannot derive on union")
42     }
43     let ty_name = s.ast().ident.to_string();
44     let decode_body = match s.variants() {
45         [vi] => vi.construct(|field, _index| decode_field(field)),
46         variants => {
47             let match_inner: TokenStream = variants
48                 .iter()
49                 .enumerate()
50                 .map(|(idx, vi)| {
51                     let construct = vi.construct(|field, _index| decode_field(field));
52                     quote! { #idx => { #construct } }
53                 })
54                 .collect();
55             let message = format!(
56                 "invalid enum variant tag while decoding `{}`, expected 0..{}",
57                 ty_name,
58                 variants.len()
59             );
60             quote! {
61                 ::rustc_serialize::Decoder::read_enum_variant(
62                     __decoder,
63                     |__decoder, __variant_idx| {
64                         match __variant_idx {
65                             #match_inner
66                             _ => panic!(#message),
67                         }
68                     })
69             }
70         }
71     };
72
73     s.bound_impl(
74         quote!(::rustc_serialize::Decodable<#decoder_ty>),
75         quote! {
76             fn decode(__decoder: &mut #decoder_ty) -> Self {
77                 #decode_body
78             }
79         },
80     )
81 }
82
83 fn decode_field(field: &syn::Field) -> proc_macro2::TokenStream {
84     let field_span = field.ident.as_ref().map_or(field.ty.span(), |ident| ident.span());
85
86     let decode_inner_method = if let syn::Type::Reference(_) = field.ty {
87         quote! { ::rustc_middle::ty::codec::RefDecodable::decode }
88     } else {
89         quote! { ::rustc_serialize::Decodable::decode }
90     };
91     let __decoder = quote! { __decoder };
92     // Use the span of the field for the method call, so
93     // that backtraces will point to the field.
94     quote_spanned! {field_span=> #decode_inner_method(#__decoder) }
95 }
96
97 pub fn type_encodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
98     if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
99         s.add_impl_generic(parse_quote! {'tcx});
100     }
101     let encoder_ty = quote! { __E };
102     s.add_impl_generic(parse_quote! {#encoder_ty: ::rustc_middle::ty::codec::TyEncoder<'tcx>});
103     s.add_bounds(synstructure::AddBounds::Generics);
104
105     encodable_body(s, encoder_ty, false)
106 }
107
108 pub fn meta_encodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
109     if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
110         s.add_impl_generic(parse_quote! {'tcx});
111     }
112     s.add_impl_generic(parse_quote! { '__a });
113     let encoder_ty = quote! { EncodeContext<'__a, 'tcx> };
114     s.add_bounds(synstructure::AddBounds::Generics);
115
116     encodable_body(s, encoder_ty, true)
117 }
118
119 pub fn encodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
120     let encoder_ty = quote! { __E };
121     s.add_impl_generic(parse_quote! { #encoder_ty: ::rustc_serialize::Encoder});
122     s.add_bounds(synstructure::AddBounds::Generics);
123
124     encodable_body(s, encoder_ty, false)
125 }
126
127 fn encodable_body(
128     mut s: synstructure::Structure<'_>,
129     encoder_ty: TokenStream,
130     allow_unreachable_code: bool,
131 ) -> proc_macro2::TokenStream {
132     if let syn::Data::Union(_) = s.ast().data {
133         panic!("cannot derive on union")
134     }
135
136     s.bind_with(|binding| {
137         // Handle the lack of a blanket reference impl.
138         if let syn::Type::Reference(_) = binding.ast().ty {
139             synstructure::BindStyle::Move
140         } else {
141             synstructure::BindStyle::Ref
142         }
143     });
144
145     let encode_body = match s.variants() {
146         [_] => {
147             let mut field_idx = 0usize;
148             let encode_inner = s.each_variant(|vi| {
149                 vi.bindings()
150                     .iter()
151                     .map(|binding| {
152                         let bind_ident = &binding.binding;
153                         let field_name = binding
154                             .ast()
155                             .ident
156                             .as_ref()
157                             .map_or_else(|| field_idx.to_string(), |i| i.to_string());
158                         let first = field_idx == 0;
159                         let result = quote! {
160                             match ::rustc_serialize::Encoder::emit_struct_field(
161                                 __encoder,
162                                 #field_name,
163                                 #first,
164                                 |__encoder|
165                                 ::rustc_serialize::Encodable::<#encoder_ty>::encode(#bind_ident, __encoder),
166                             ) {
167                                 ::std::result::Result::Ok(()) => (),
168                                 ::std::result::Result::Err(__err)
169                                     => return ::std::result::Result::Err(__err),
170                             }
171                         };
172                         field_idx += 1;
173                         result
174                     })
175                     .collect::<TokenStream>()
176             });
177             let no_fields = field_idx == 0;
178             quote! {
179                 ::rustc_serialize::Encoder::emit_struct(__encoder, #no_fields, |__encoder| {
180                     ::std::result::Result::Ok(match *self { #encode_inner })
181                 })
182             }
183         }
184         _ => {
185             let mut variant_idx = 0usize;
186             let encode_inner = s.each_variant(|vi| {
187                 let variant_name = vi.ast().ident.to_string();
188                 let mut field_idx = 0usize;
189
190                 let encode_fields: TokenStream = vi
191                     .bindings()
192                     .iter()
193                     .map(|binding| {
194                         let bind_ident = &binding.binding;
195                         let first = field_idx == 0;
196                         let result = quote! {
197                             match ::rustc_serialize::Encoder::emit_enum_variant_arg(
198                                 __encoder,
199                                 #first,
200                                 |__encoder|
201                                 ::rustc_serialize::Encodable::<#encoder_ty>::encode(#bind_ident, __encoder),
202                             ) {
203                                 ::std::result::Result::Ok(()) => (),
204                                 ::std::result::Result::Err(__err)
205                                     => return ::std::result::Result::Err(__err),
206                             }
207                         };
208                         field_idx += 1;
209                         result
210                     })
211                     .collect();
212
213                 let result = if field_idx != 0 {
214                     quote! {
215                         ::rustc_serialize::Encoder::emit_enum_variant(
216                             __encoder,
217                             #variant_name,
218                             #variant_idx,
219                             #field_idx,
220                             |__encoder| { ::std::result::Result::Ok({ #encode_fields }) }
221                         )
222                     }
223                 } else {
224                     quote! {
225                         ::rustc_serialize::Encoder::emit_fieldless_enum_variant::<#variant_idx>(
226                             __encoder,
227                             #variant_name,
228                         )
229                     }
230                 };
231                 variant_idx += 1;
232                 result
233             });
234             quote! {
235                 ::rustc_serialize::Encoder::emit_enum(__encoder, |__encoder| {
236                     match *self {
237                         #encode_inner
238                     }
239                 })
240             }
241         }
242     };
243
244     let lints = if allow_unreachable_code {
245         quote! { #![allow(unreachable_code)] }
246     } else {
247         quote! {}
248     };
249
250     s.bound_impl(
251         quote!(::rustc_serialize::Encodable<#encoder_ty>),
252         quote! {
253             fn encode(
254                 &self,
255                 __encoder: &mut #encoder_ty,
256             ) -> ::std::result::Result<(), <#encoder_ty as ::rustc_serialize::Encoder>::Error> {
257                 #lints
258                 #encode_body
259             }
260         },
261     )
262 }