]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_macros/src/newtype.rs
Rollup merge of #105623 - compiler-errors:generator-type-size-fix, r=Nilstrieb
[rust.git] / compiler / rustc_macros / src / newtype.rs
1 use proc_macro2::{Span, TokenStream};
2 use quote::quote;
3 use syn::parse::*;
4 use syn::punctuated::Punctuated;
5 use syn::*;
6
7 mod kw {
8     syn::custom_keyword!(derive);
9     syn::custom_keyword!(DEBUG_FORMAT);
10     syn::custom_keyword!(MAX);
11     syn::custom_keyword!(ENCODABLE);
12     syn::custom_keyword!(custom);
13     syn::custom_keyword!(ORD_IMPL);
14 }
15
/// How the `Debug` impl of the generated type is produced.
#[derive(Debug)]
enum DebugFormat {
    /// Generate a `Debug` impl that uses this format string.
    /// When nothing is specified, the format string is "{}".
    Format(String),
    /// A hand-written `Debug` impl is supplied by the user, so none is
    /// generated.
    Custom,
}
25
// Parsing the input and emitting the output happen in a single step:
// the `Parse` impl builds the complete expansion while it reads the input.
// The single field stores that final macro output.
struct Newtype(TokenStream);
29
30 impl Parse for Newtype {
31     fn parse(input: ParseStream<'_>) -> Result<Self> {
32         let attrs = input.call(Attribute::parse_outer)?;
33         let vis: Visibility = input.parse()?;
34         input.parse::<Token![struct]>()?;
35         let name: Ident = input.parse()?;
36
37         let body;
38         braced!(body in input);
39
40         // Any additional `#[derive]` macro paths to apply
41         let mut derive_paths: Vec<Path> = Vec::new();
42         let mut debug_format: Option<DebugFormat> = None;
43         let mut max = None;
44         let mut consts = Vec::new();
45         let mut encodable = true;
46         let mut ord = true;
47
48         // Parse an optional trailing comma
49         let try_comma = || -> Result<()> {
50             if body.lookahead1().peek(Token![,]) {
51                 body.parse::<Token![,]>()?;
52             }
53             Ok(())
54         };
55
56         if body.lookahead1().peek(Token![..]) {
57             body.parse::<Token![..]>()?;
58         } else {
59             loop {
60                 if body.lookahead1().peek(kw::derive) {
61                     body.parse::<kw::derive>()?;
62                     let derives;
63                     bracketed!(derives in body);
64                     let derives: Punctuated<Path, Token![,]> =
65                         derives.parse_terminated(Path::parse)?;
66                     try_comma()?;
67                     derive_paths.extend(derives);
68                     continue;
69                 }
70                 if body.lookahead1().peek(kw::DEBUG_FORMAT) {
71                     body.parse::<kw::DEBUG_FORMAT>()?;
72                     body.parse::<Token![=]>()?;
73                     let new_debug_format = if body.lookahead1().peek(kw::custom) {
74                         body.parse::<kw::custom>()?;
75                         DebugFormat::Custom
76                     } else {
77                         let format_str: LitStr = body.parse()?;
78                         DebugFormat::Format(format_str.value())
79                     };
80                     try_comma()?;
81                     if let Some(old) = debug_format.replace(new_debug_format) {
82                         panic!("Specified multiple debug format options: {:?}", old);
83                     }
84                     continue;
85                 }
86                 if body.lookahead1().peek(kw::MAX) {
87                     body.parse::<kw::MAX>()?;
88                     body.parse::<Token![=]>()?;
89                     let val: Lit = body.parse()?;
90                     try_comma()?;
91                     if let Some(old) = max.replace(val) {
92                         panic!("Specified multiple MAX: {:?}", old);
93                     }
94                     continue;
95                 }
96                 if body.lookahead1().peek(kw::ENCODABLE) {
97                     body.parse::<kw::ENCODABLE>()?;
98                     body.parse::<Token![=]>()?;
99                     body.parse::<kw::custom>()?;
100                     try_comma()?;
101                     encodable = false;
102                     continue;
103                 }
104                 if body.lookahead1().peek(kw::ORD_IMPL) {
105                     body.parse::<kw::ORD_IMPL>()?;
106                     body.parse::<Token![=]>()?;
107                     body.parse::<kw::custom>()?;
108                     ord = false;
109                     continue;
110                 }
111
112                 // We've parsed everything that the user provided, so we're done
113                 if body.is_empty() {
114                     break;
115                 }
116
117                 // Otherwise, we are parsing a user-defined constant
118                 let const_attrs = body.call(Attribute::parse_outer)?;
119                 body.parse::<Token![const]>()?;
120                 let const_name: Ident = body.parse()?;
121                 body.parse::<Token![=]>()?;
122                 let const_val: Expr = body.parse()?;
123                 try_comma()?;
124                 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
125             }
126         }
127
128         let debug_format = debug_format.unwrap_or(DebugFormat::Format("{}".to_string()));
129         // shave off 256 indices at the end to allow space for packing these indices into enums
130         let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
131
132         let encodable_impls = if encodable {
133             quote! {
134                 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
135                     fn decode(d: &mut D) -> Self {
136                         Self::from_u32(d.read_u32())
137                     }
138                 }
139                 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
140                     fn encode(&self, e: &mut E) {
141                         e.emit_u32(self.private);
142                     }
143                 }
144             }
145         } else {
146             quote! {}
147         };
148
149         if ord {
150             derive_paths.push(parse_quote!(Ord));
151             derive_paths.push(parse_quote!(PartialOrd));
152         }
153
154         let step = if ord {
155             quote! {
156                 impl ::std::iter::Step for #name {
157                     #[inline]
158                     fn steps_between(start: &Self, end: &Self) -> Option<usize> {
159                         <usize as ::std::iter::Step>::steps_between(
160                             &Self::index(*start),
161                             &Self::index(*end),
162                         )
163                     }
164
165                     #[inline]
166                     fn forward_checked(start: Self, u: usize) -> Option<Self> {
167                         Self::index(start).checked_add(u).map(Self::from_usize)
168                     }
169
170                     #[inline]
171                     fn backward_checked(start: Self, u: usize) -> Option<Self> {
172                         Self::index(start).checked_sub(u).map(Self::from_usize)
173                     }
174                 }
175
176                 // Safety: The implementation of `Step` upholds all invariants.
177                 unsafe impl ::std::iter::TrustedStep for #name {}
178             }
179         } else {
180             quote! {}
181         };
182
183         let debug_impl = match debug_format {
184             DebugFormat::Custom => quote! {},
185             DebugFormat::Format(format) => {
186                 quote! {
187                     impl ::std::fmt::Debug for #name {
188                         fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
189                             write!(fmt, #format, self.as_u32())
190                         }
191                     }
192                 }
193             }
194         };
195         let spec_partial_eq_impl = if let Lit::Int(max) = &max {
196             if let Ok(max_val) = max.base10_parse::<u32>() {
197                 quote! {
198                     impl core::option::SpecOptionPartialEq for #name {
199                         #[inline]
200                         fn eq(l: &Option<Self>, r: &Option<Self>) -> bool {
201                             if #max_val < u32::MAX {
202                                 l.map(|i| i.private).unwrap_or(#max_val+1) == r.map(|i| i.private).unwrap_or(#max_val+1)
203                             } else {
204                                 match (l, r) {
205                                     (Some(l), Some(r)) => r == l,
206                                     (None, None) => true,
207                                     _ => false
208                                 }
209                             }
210                         }
211                     }
212                 }
213             } else {
214                 quote! {}
215             }
216         } else {
217             quote! {}
218         };
219
220         Ok(Self(quote! {
221             #(#attrs)*
222             #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
223             #[rustc_layout_scalar_valid_range_end(#max)]
224             #[rustc_pass_by_value]
225             #vis struct #name {
226                 private: u32,
227             }
228
229             #(#consts)*
230
231             impl #name {
232                 /// Maximum value the index can take, as a `u32`.
233                 #vis const MAX_AS_U32: u32  = #max;
234
235                 /// Maximum value the index can take.
236                 #vis const MAX: Self = Self::from_u32(#max);
237
238                 /// Creates a new index from a given `usize`.
239                 ///
240                 /// # Panics
241                 ///
242                 /// Will panic if `value` exceeds `MAX`.
243                 #[inline]
244                 #vis const fn from_usize(value: usize) -> Self {
245                     assert!(value <= (#max as usize));
246                     // SAFETY: We just checked that `value <= max`.
247                     unsafe {
248                         Self::from_u32_unchecked(value as u32)
249                     }
250                 }
251
252                 /// Creates a new index from a given `u32`.
253                 ///
254                 /// # Panics
255                 ///
256                 /// Will panic if `value` exceeds `MAX`.
257                 #[inline]
258                 #vis const fn from_u32(value: u32) -> Self {
259                     assert!(value <= #max);
260                     // SAFETY: We just checked that `value <= max`.
261                     unsafe {
262                         Self::from_u32_unchecked(value)
263                     }
264                 }
265
266                 /// Creates a new index from a given `u32`.
267                 ///
268                 /// # Safety
269                 ///
270                 /// The provided value must be less than or equal to the maximum value for the newtype.
271                 /// Providing a value outside this range is undefined due to layout restrictions.
272                 ///
273                 /// Prefer using `from_u32`.
274                 #[inline]
275                 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
276                     Self { private: value }
277                 }
278
279                 /// Extracts the value of this index as a `usize`.
280                 #[inline]
281                 #vis const fn index(self) -> usize {
282                     self.as_usize()
283                 }
284
285                 /// Extracts the value of this index as a `u32`.
286                 #[inline]
287                 #vis const fn as_u32(self) -> u32 {
288                     self.private
289                 }
290
291                 /// Extracts the value of this index as a `usize`.
292                 #[inline]
293                 #vis const fn as_usize(self) -> usize {
294                     self.as_u32() as usize
295                 }
296             }
297
298             impl std::ops::Add<usize> for #name {
299                 type Output = Self;
300
301                 fn add(self, other: usize) -> Self {
302                     Self::from_usize(self.index() + other)
303                 }
304             }
305
306             impl rustc_index::vec::Idx for #name {
307                 #[inline]
308                 fn new(value: usize) -> Self {
309                     Self::from_usize(value)
310                 }
311
312                 #[inline]
313                 fn index(self) -> usize {
314                     self.as_usize()
315                 }
316             }
317
318             #step
319
320             #spec_partial_eq_impl
321
322             impl From<#name> for u32 {
323                 #[inline]
324                 fn from(v: #name) -> u32 {
325                     v.as_u32()
326                 }
327             }
328
329             impl From<#name> for usize {
330                 #[inline]
331                 fn from(v: #name) -> usize {
332                     v.as_usize()
333                 }
334             }
335
336             impl From<usize> for #name {
337                 #[inline]
338                 fn from(value: usize) -> Self {
339                     Self::from_usize(value)
340                 }
341             }
342
343             impl From<u32> for #name {
344                 #[inline]
345                 fn from(value: u32) -> Self {
346                     Self::from_u32(value)
347                 }
348             }
349
350             #encodable_impls
351             #debug_impl
352         }))
353     }
354 }
355
356 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
357     let input = parse_macro_input!(input as Newtype);
358     input.0.into()
359 }