]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_macros/src/newtype.rs
Make `#[debug_format]` an attribute in `newtype_index`
[rust.git] / compiler / rustc_macros / src / newtype.rs
1 use proc_macro2::{Span, TokenStream};
2 use quote::quote;
3 use syn::parse::*;
4 use syn::*;
5
6 // We parse the input and emit the output in a single step.
7 // This field stores the final macro output
8 struct Newtype(TokenStream);
9
10 impl Parse for Newtype {
11     fn parse(input: ParseStream<'_>) -> Result<Self> {
12         let mut attrs = input.call(Attribute::parse_outer)?;
13         let vis: Visibility = input.parse()?;
14         input.parse::<Token![struct]>()?;
15         let name: Ident = input.parse()?;
16
17         let body;
18         braced!(body in input);
19
20         // Any additional `#[derive]` macro paths to apply
21         let mut derive_paths: Vec<Path> = Vec::new();
22         let mut debug_format: Option<Lit> = None;
23         let mut max = None;
24         let mut consts = Vec::new();
25         let mut encodable = true;
26         let mut ord = true;
27
28         // Parse an optional trailing comma
29         let try_comma = || -> Result<()> {
30             if body.lookahead1().peek(Token![,]) {
31                 body.parse::<Token![,]>()?;
32             }
33             Ok(())
34         };
35
36         attrs.retain(|attr| match attr.path.get_ident() {
37             Some(ident) => match &*ident.to_string() {
38                 "custom_encodable" => {
39                     encodable = false;
40                     false
41                 }
42                 "no_ord_impl" => {
43                     ord = false;
44                     false
45                 }
46                 "max" => {
47                     let Ok(Meta::NameValue(literal) )= attr.parse_meta() else {
48                         panic!("#[max = NUMBER] attribute requires max value");
49                     };
50
51                     if let Some(old) = max.replace(literal.lit) {
52                         panic!("Specified multiple max: {:?}", old);
53                     }
54
55                     false
56                 }
57                 "debug_format" => {
58                     let Ok(Meta::NameValue(literal) )= attr.parse_meta() else {
59                         panic!("#[debug_format = FMT] attribute requires a format");
60                     };
61
62                     if let Some(old) = debug_format.replace(literal.lit) {
63                         panic!("Specified multiple debug format options: {:?}", old);
64                     }
65
66                     false
67                 }
68                 _ => true,
69             },
70             _ => true,
71         });
72
73         if body.lookahead1().peek(Token![..]) {
74             body.parse::<Token![..]>()?;
75         } else {
76             loop {
77                 // We've parsed everything that the user provided, so we're done
78                 if body.is_empty() {
79                     break;
80                 }
81
82                 // Otherwise, we are parsing a user-defined constant
83                 let const_attrs = body.call(Attribute::parse_outer)?;
84                 body.parse::<Token![const]>()?;
85                 let const_name: Ident = body.parse()?;
86                 body.parse::<Token![=]>()?;
87                 let const_val: Expr = body.parse()?;
88                 try_comma()?;
89                 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
90             }
91         }
92
93         let debug_format =
94             debug_format.unwrap_or_else(|| Lit::Str(LitStr::new("{}", Span::call_site())));
95
96         // shave off 256 indices at the end to allow space for packing these indices into enums
97         let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
98
99         let encodable_impls = if encodable {
100             quote! {
101                 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
102                     fn decode(d: &mut D) -> Self {
103                         Self::from_u32(d.read_u32())
104                     }
105                 }
106                 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
107                     fn encode(&self, e: &mut E) {
108                         e.emit_u32(self.private);
109                     }
110                 }
111             }
112         } else {
113             quote! {}
114         };
115
116         if ord {
117             derive_paths.push(parse_quote!(Ord));
118             derive_paths.push(parse_quote!(PartialOrd));
119         }
120
121         let step = if ord {
122             quote! {
123                 impl ::std::iter::Step for #name {
124                     #[inline]
125                     fn steps_between(start: &Self, end: &Self) -> Option<usize> {
126                         <usize as ::std::iter::Step>::steps_between(
127                             &Self::index(*start),
128                             &Self::index(*end),
129                         )
130                     }
131
132                     #[inline]
133                     fn forward_checked(start: Self, u: usize) -> Option<Self> {
134                         Self::index(start).checked_add(u).map(Self::from_usize)
135                     }
136
137                     #[inline]
138                     fn backward_checked(start: Self, u: usize) -> Option<Self> {
139                         Self::index(start).checked_sub(u).map(Self::from_usize)
140                     }
141                 }
142
143                 // Safety: The implementation of `Step` upholds all invariants.
144                 unsafe impl ::std::iter::TrustedStep for #name {}
145             }
146         } else {
147             quote! {}
148         };
149
150         let debug_impl = quote! {
151             impl ::std::fmt::Debug for #name {
152                 fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
153                     write!(fmt, #debug_format, self.as_u32())
154                 }
155             }
156         };
157
158         let spec_partial_eq_impl = if let Lit::Int(max) = &max {
159             if let Ok(max_val) = max.base10_parse::<u32>() {
160                 quote! {
161                     impl core::option::SpecOptionPartialEq for #name {
162                         #[inline]
163                         fn eq(l: &Option<Self>, r: &Option<Self>) -> bool {
164                             if #max_val < u32::MAX {
165                                 l.map(|i| i.private).unwrap_or(#max_val+1) == r.map(|i| i.private).unwrap_or(#max_val+1)
166                             } else {
167                                 match (l, r) {
168                                     (Some(l), Some(r)) => r == l,
169                                     (None, None) => true,
170                                     _ => false
171                                 }
172                             }
173                         }
174                     }
175                 }
176             } else {
177                 quote! {}
178             }
179         } else {
180             quote! {}
181         };
182
183         Ok(Self(quote! {
184             #(#attrs)*
185             #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
186             #[rustc_layout_scalar_valid_range_end(#max)]
187             #[rustc_pass_by_value]
188             #vis struct #name {
189                 private: u32,
190             }
191
192             #(#consts)*
193
194             impl #name {
195                 /// Maximum value the index can take, as a `u32`.
196                 #vis const MAX_AS_U32: u32  = #max;
197
198                 /// Maximum value the index can take.
199                 #vis const MAX: Self = Self::from_u32(#max);
200
201                 /// Creates a new index from a given `usize`.
202                 ///
203                 /// # Panics
204                 ///
205                 /// Will panic if `value` exceeds `MAX`.
206                 #[inline]
207                 #vis const fn from_usize(value: usize) -> Self {
208                     assert!(value <= (#max as usize));
209                     // SAFETY: We just checked that `value <= max`.
210                     unsafe {
211                         Self::from_u32_unchecked(value as u32)
212                     }
213                 }
214
215                 /// Creates a new index from a given `u32`.
216                 ///
217                 /// # Panics
218                 ///
219                 /// Will panic if `value` exceeds `MAX`.
220                 #[inline]
221                 #vis const fn from_u32(value: u32) -> Self {
222                     assert!(value <= #max);
223                     // SAFETY: We just checked that `value <= max`.
224                     unsafe {
225                         Self::from_u32_unchecked(value)
226                     }
227                 }
228
229                 /// Creates a new index from a given `u32`.
230                 ///
231                 /// # Safety
232                 ///
233                 /// The provided value must be less than or equal to the maximum value for the newtype.
234                 /// Providing a value outside this range is undefined due to layout restrictions.
235                 ///
236                 /// Prefer using `from_u32`.
237                 #[inline]
238                 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
239                     Self { private: value }
240                 }
241
242                 /// Extracts the value of this index as a `usize`.
243                 #[inline]
244                 #vis const fn index(self) -> usize {
245                     self.as_usize()
246                 }
247
248                 /// Extracts the value of this index as a `u32`.
249                 #[inline]
250                 #vis const fn as_u32(self) -> u32 {
251                     self.private
252                 }
253
254                 /// Extracts the value of this index as a `usize`.
255                 #[inline]
256                 #vis const fn as_usize(self) -> usize {
257                     self.as_u32() as usize
258                 }
259             }
260
261             impl std::ops::Add<usize> for #name {
262                 type Output = Self;
263
264                 fn add(self, other: usize) -> Self {
265                     Self::from_usize(self.index() + other)
266                 }
267             }
268
269             impl rustc_index::vec::Idx for #name {
270                 #[inline]
271                 fn new(value: usize) -> Self {
272                     Self::from_usize(value)
273                 }
274
275                 #[inline]
276                 fn index(self) -> usize {
277                     self.as_usize()
278                 }
279             }
280
281             #step
282
283             #spec_partial_eq_impl
284
285             impl From<#name> for u32 {
286                 #[inline]
287                 fn from(v: #name) -> u32 {
288                     v.as_u32()
289                 }
290             }
291
292             impl From<#name> for usize {
293                 #[inline]
294                 fn from(v: #name) -> usize {
295                     v.as_usize()
296                 }
297             }
298
299             impl From<usize> for #name {
300                 #[inline]
301                 fn from(value: usize) -> Self {
302                     Self::from_usize(value)
303                 }
304             }
305
306             impl From<u32> for #name {
307                 #[inline]
308                 fn from(value: u32) -> Self {
309                     Self::from_u32(value)
310                 }
311             }
312
313             #encodable_impls
314             #debug_impl
315         }))
316     }
317 }
318
319 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
320     let input = parse_macro_input!(input as Newtype);
321     input.0.into()
322 }