]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_macros/src/serialize.rs
Rollup merge of #96029 - IsakNyberg:error-messages-fix, r=Dylan-DPC
[rust.git] / compiler / rustc_macros / src / serialize.rs
1 use proc_macro2::TokenStream;
2 use quote::{quote, quote_spanned};
3 use syn::parse_quote;
4 use syn::spanned::Spanned;
5
6 pub fn type_decodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
7     let decoder_ty = quote! { __D };
8     if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
9         s.add_impl_generic(parse_quote! { 'tcx });
10     }
11     s.add_impl_generic(parse_quote! {#decoder_ty: ::rustc_middle::ty::codec::TyDecoder<'tcx>});
12     s.add_bounds(synstructure::AddBounds::Generics);
13
14     decodable_body(s, decoder_ty)
15 }
16
17 pub fn meta_decodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
18     if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
19         s.add_impl_generic(parse_quote! { 'tcx });
20     }
21     s.add_impl_generic(parse_quote! { '__a });
22     let decoder_ty = quote! { DecodeContext<'__a, 'tcx> };
23     s.add_bounds(synstructure::AddBounds::Generics);
24
25     decodable_body(s, decoder_ty)
26 }
27
28 pub fn decodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
29     let decoder_ty = quote! { __D };
30     s.add_impl_generic(parse_quote! {#decoder_ty: ::rustc_serialize::Decoder});
31     s.add_bounds(synstructure::AddBounds::Generics);
32
33     decodable_body(s, decoder_ty)
34 }
35
36 fn decodable_body(
37     s: synstructure::Structure<'_>,
38     decoder_ty: TokenStream,
39 ) -> proc_macro2::TokenStream {
40     if let syn::Data::Union(_) = s.ast().data {
41         panic!("cannot derive on union")
42     }
43     let ty_name = s.ast().ident.to_string();
44     let decode_body = match s.variants() {
45         [vi] => vi.construct(|field, _index| decode_field(field)),
46         variants => {
47             let match_inner: TokenStream = variants
48                 .iter()
49                 .enumerate()
50                 .map(|(idx, vi)| {
51                     let construct = vi.construct(|field, _index| decode_field(field));
52                     quote! { #idx => { #construct } }
53                 })
54                 .collect();
55             let message = format!(
56                 "invalid enum variant tag while decoding `{}`, expected 0..{}",
57                 ty_name,
58                 variants.len()
59             );
60             quote! {
61                 match ::rustc_serialize::Decoder::read_usize(__decoder) {
62                     #match_inner
63                     _ => panic!(#message),
64                 }
65             }
66         }
67     };
68
69     s.bound_impl(
70         quote!(::rustc_serialize::Decodable<#decoder_ty>),
71         quote! {
72             fn decode(__decoder: &mut #decoder_ty) -> Self {
73                 #decode_body
74             }
75         },
76     )
77 }
78
79 fn decode_field(field: &syn::Field) -> proc_macro2::TokenStream {
80     let field_span = field.ident.as_ref().map_or(field.ty.span(), |ident| ident.span());
81
82     let decode_inner_method = if let syn::Type::Reference(_) = field.ty {
83         quote! { ::rustc_middle::ty::codec::RefDecodable::decode }
84     } else {
85         quote! { ::rustc_serialize::Decodable::decode }
86     };
87     let __decoder = quote! { __decoder };
88     // Use the span of the field for the method call, so
89     // that backtraces will point to the field.
90     quote_spanned! {field_span=> #decode_inner_method(#__decoder) }
91 }
92
93 pub fn type_encodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
94     if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
95         s.add_impl_generic(parse_quote! {'tcx});
96     }
97     let encoder_ty = quote! { __E };
98     s.add_impl_generic(parse_quote! {#encoder_ty: ::rustc_middle::ty::codec::TyEncoder<'tcx>});
99     s.add_bounds(synstructure::AddBounds::Generics);
100
101     encodable_body(s, encoder_ty, false)
102 }
103
104 pub fn meta_encodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
105     if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
106         s.add_impl_generic(parse_quote! {'tcx});
107     }
108     s.add_impl_generic(parse_quote! { '__a });
109     let encoder_ty = quote! { EncodeContext<'__a, 'tcx> };
110     s.add_bounds(synstructure::AddBounds::Generics);
111
112     encodable_body(s, encoder_ty, true)
113 }
114
115 pub fn encodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
116     let encoder_ty = quote! { __E };
117     s.add_impl_generic(parse_quote! { #encoder_ty: ::rustc_serialize::Encoder});
118     s.add_bounds(synstructure::AddBounds::Generics);
119
120     encodable_body(s, encoder_ty, false)
121 }
122
123 fn encodable_body(
124     mut s: synstructure::Structure<'_>,
125     encoder_ty: TokenStream,
126     allow_unreachable_code: bool,
127 ) -> proc_macro2::TokenStream {
128     if let syn::Data::Union(_) = s.ast().data {
129         panic!("cannot derive on union")
130     }
131
132     s.bind_with(|binding| {
133         // Handle the lack of a blanket reference impl.
134         if let syn::Type::Reference(_) = binding.ast().ty {
135             synstructure::BindStyle::Move
136         } else {
137             synstructure::BindStyle::Ref
138         }
139     });
140
141     let encode_body = match s.variants() {
142         [_] => {
143             let mut field_idx = 0usize;
144             let encode_inner = s.each_variant(|vi| {
145                 vi.bindings()
146                     .iter()
147                     .map(|binding| {
148                         let bind_ident = &binding.binding;
149                         let field_name = binding
150                             .ast()
151                             .ident
152                             .as_ref()
153                             .map_or_else(|| field_idx.to_string(), |i| i.to_string());
154                         let first = field_idx == 0;
155                         let result = quote! {
156                             match ::rustc_serialize::Encoder::emit_struct_field(
157                                 __encoder,
158                                 #field_name,
159                                 #first,
160                                 |__encoder|
161                                 ::rustc_serialize::Encodable::<#encoder_ty>::encode(#bind_ident, __encoder),
162                             ) {
163                                 ::std::result::Result::Ok(()) => (),
164                                 ::std::result::Result::Err(__err)
165                                     => return ::std::result::Result::Err(__err),
166                             }
167                         };
168                         field_idx += 1;
169                         result
170                     })
171                     .collect::<TokenStream>()
172             });
173             let no_fields = field_idx == 0;
174             quote! {
175                 ::rustc_serialize::Encoder::emit_struct(__encoder, #no_fields, |__encoder| {
176                     ::std::result::Result::Ok(match *self { #encode_inner })
177                 })
178             }
179         }
180         _ => {
181             let mut variant_idx = 0usize;
182             let encode_inner = s.each_variant(|vi| {
183                 let variant_name = vi.ast().ident.to_string();
184                 let mut field_idx = 0usize;
185
186                 let encode_fields: TokenStream = vi
187                     .bindings()
188                     .iter()
189                     .map(|binding| {
190                         let bind_ident = &binding.binding;
191                         let first = field_idx == 0;
192                         let result = quote! {
193                             match ::rustc_serialize::Encoder::emit_enum_variant_arg(
194                                 __encoder,
195                                 #first,
196                                 |__encoder|
197                                 ::rustc_serialize::Encodable::<#encoder_ty>::encode(#bind_ident, __encoder),
198                             ) {
199                                 ::std::result::Result::Ok(()) => (),
200                                 ::std::result::Result::Err(__err)
201                                     => return ::std::result::Result::Err(__err),
202                             }
203                         };
204                         field_idx += 1;
205                         result
206                     })
207                     .collect();
208
209                 let result = if field_idx != 0 {
210                     quote! {
211                         ::rustc_serialize::Encoder::emit_enum_variant(
212                             __encoder,
213                             #variant_name,
214                             #variant_idx,
215                             #field_idx,
216                             |__encoder| { ::std::result::Result::Ok({ #encode_fields }) }
217                         )
218                     }
219                 } else {
220                     quote! {
221                         ::rustc_serialize::Encoder::emit_fieldless_enum_variant::<#variant_idx>(
222                             __encoder,
223                             #variant_name,
224                         )
225                     }
226                 };
227                 variant_idx += 1;
228                 result
229             });
230             quote! {
231                 ::rustc_serialize::Encoder::emit_enum(__encoder, |__encoder| {
232                     match *self {
233                         #encode_inner
234                     }
235                 })
236             }
237         }
238     };
239
240     let lints = if allow_unreachable_code {
241         quote! { #![allow(unreachable_code)] }
242     } else {
243         quote! {}
244     };
245
246     s.bound_impl(
247         quote!(::rustc_serialize::Encodable<#encoder_ty>),
248         quote! {
249             fn encode(
250                 &self,
251                 __encoder: &mut #encoder_ty,
252             ) -> ::std::result::Result<(), <#encoder_ty as ::rustc_serialize::Encoder>::Error> {
253                 #lints
254                 #encode_body
255             }
256         },
257     )
258 }