]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_macros/src/newtype.rs
Make `#[max]` an attribute in `newtype_index`
[rust.git] / compiler / rustc_macros / src / newtype.rs
1 use proc_macro2::{Span, TokenStream};
2 use quote::quote;
3 use syn::parse::*;
4 use syn::*;
5
6 mod kw {
7     syn::custom_keyword!(DEBUG_FORMAT);
8     syn::custom_keyword!(MAX);
9     syn::custom_keyword!(custom);
10 }
11
/// Selects how (or whether) the macro emits a `Debug` impl for the index type.
#[derive(Debug)]
enum DebugFormat {
    /// `DEBUG_FORMAT = custom`: the user supplies their own `Debug` impl, so the
    /// macro emits none.
    Custom,
    /// `DEBUG_FORMAT = "..."`: the emitted `Debug` impl writes the index's `u32`
    /// value through this format string. When `DEBUG_FORMAT` is omitted the
    /// macro falls back to `"{}"`.
    Format(String),
}
21
22 // We parse the input and emit the output in a single step.
23 // This field stores the final macro output
24 struct Newtype(TokenStream);
25
26 impl Parse for Newtype {
27     fn parse(input: ParseStream<'_>) -> Result<Self> {
28         let mut attrs = input.call(Attribute::parse_outer)?;
29         let vis: Visibility = input.parse()?;
30         input.parse::<Token![struct]>()?;
31         let name: Ident = input.parse()?;
32
33         let body;
34         braced!(body in input);
35
36         // Any additional `#[derive]` macro paths to apply
37         let mut derive_paths: Vec<Path> = Vec::new();
38         let mut debug_format: Option<DebugFormat> = None;
39         let mut max = None;
40         let mut consts = Vec::new();
41         let mut encodable = true;
42         let mut ord = true;
43
44         // Parse an optional trailing comma
45         let try_comma = || -> Result<()> {
46             if body.lookahead1().peek(Token![,]) {
47                 body.parse::<Token![,]>()?;
48             }
49             Ok(())
50         };
51
52         attrs.retain(|attr| match attr.path.get_ident() {
53             Some(ident) => match &*ident.to_string() {
54                 "custom_encodable" => {
55                     encodable = false;
56                     false
57                 }
58                 "no_ord_impl" => {
59                     ord = false;
60                     false
61                 }
62                 "max" => {
63                     let Ok(Meta::NameValue(literal) )= attr.parse_meta() else {
64                         panic!("#[max = NUMBER] attribute requires max value");
65                     };
66
67                     if let Some(old) = max.replace(literal.lit) {
68                         panic!("Specified multiple MAX: {:?}", old);
69                     }
70
71                     false
72                 }
73                 _ => true,
74             },
75             _ => true,
76         });
77
78         if body.lookahead1().peek(Token![..]) {
79             body.parse::<Token![..]>()?;
80         } else {
81             loop {
82                 if body.lookahead1().peek(kw::DEBUG_FORMAT) {
83                     body.parse::<kw::DEBUG_FORMAT>()?;
84                     body.parse::<Token![=]>()?;
85                     let new_debug_format = if body.lookahead1().peek(kw::custom) {
86                         body.parse::<kw::custom>()?;
87                         DebugFormat::Custom
88                     } else {
89                         let format_str: LitStr = body.parse()?;
90                         DebugFormat::Format(format_str.value())
91                     };
92                     try_comma()?;
93                     if let Some(old) = debug_format.replace(new_debug_format) {
94                         panic!("Specified multiple debug format options: {:?}", old);
95                     }
96                     continue;
97                 }
98
99                 // We've parsed everything that the user provided, so we're done
100                 if body.is_empty() {
101                     break;
102                 }
103
104                 // Otherwise, we are parsing a user-defined constant
105                 let const_attrs = body.call(Attribute::parse_outer)?;
106                 body.parse::<Token![const]>()?;
107                 let const_name: Ident = body.parse()?;
108                 body.parse::<Token![=]>()?;
109                 let const_val: Expr = body.parse()?;
110                 try_comma()?;
111                 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
112             }
113         }
114
115         let debug_format = debug_format.unwrap_or(DebugFormat::Format("{}".to_string()));
116         // shave off 256 indices at the end to allow space for packing these indices into enums
117         let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
118
119         let encodable_impls = if encodable {
120             quote! {
121                 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
122                     fn decode(d: &mut D) -> Self {
123                         Self::from_u32(d.read_u32())
124                     }
125                 }
126                 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
127                     fn encode(&self, e: &mut E) {
128                         e.emit_u32(self.private);
129                     }
130                 }
131             }
132         } else {
133             quote! {}
134         };
135
136         if ord {
137             derive_paths.push(parse_quote!(Ord));
138             derive_paths.push(parse_quote!(PartialOrd));
139         }
140
141         let step = if ord {
142             quote! {
143                 impl ::std::iter::Step for #name {
144                     #[inline]
145                     fn steps_between(start: &Self, end: &Self) -> Option<usize> {
146                         <usize as ::std::iter::Step>::steps_between(
147                             &Self::index(*start),
148                             &Self::index(*end),
149                         )
150                     }
151
152                     #[inline]
153                     fn forward_checked(start: Self, u: usize) -> Option<Self> {
154                         Self::index(start).checked_add(u).map(Self::from_usize)
155                     }
156
157                     #[inline]
158                     fn backward_checked(start: Self, u: usize) -> Option<Self> {
159                         Self::index(start).checked_sub(u).map(Self::from_usize)
160                     }
161                 }
162
163                 // Safety: The implementation of `Step` upholds all invariants.
164                 unsafe impl ::std::iter::TrustedStep for #name {}
165             }
166         } else {
167             quote! {}
168         };
169
170         let debug_impl = match debug_format {
171             DebugFormat::Custom => quote! {},
172             DebugFormat::Format(format) => {
173                 quote! {
174                     impl ::std::fmt::Debug for #name {
175                         fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
176                             write!(fmt, #format, self.as_u32())
177                         }
178                     }
179                 }
180             }
181         };
182         let spec_partial_eq_impl = if let Lit::Int(max) = &max {
183             if let Ok(max_val) = max.base10_parse::<u32>() {
184                 quote! {
185                     impl core::option::SpecOptionPartialEq for #name {
186                         #[inline]
187                         fn eq(l: &Option<Self>, r: &Option<Self>) -> bool {
188                             if #max_val < u32::MAX {
189                                 l.map(|i| i.private).unwrap_or(#max_val+1) == r.map(|i| i.private).unwrap_or(#max_val+1)
190                             } else {
191                                 match (l, r) {
192                                     (Some(l), Some(r)) => r == l,
193                                     (None, None) => true,
194                                     _ => false
195                                 }
196                             }
197                         }
198                     }
199                 }
200             } else {
201                 quote! {}
202             }
203         } else {
204             quote! {}
205         };
206
207         Ok(Self(quote! {
208             #(#attrs)*
209             #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
210             #[rustc_layout_scalar_valid_range_end(#max)]
211             #[rustc_pass_by_value]
212             #vis struct #name {
213                 private: u32,
214             }
215
216             #(#consts)*
217
218             impl #name {
219                 /// Maximum value the index can take, as a `u32`.
220                 #vis const MAX_AS_U32: u32  = #max;
221
222                 /// Maximum value the index can take.
223                 #vis const MAX: Self = Self::from_u32(#max);
224
225                 /// Creates a new index from a given `usize`.
226                 ///
227                 /// # Panics
228                 ///
229                 /// Will panic if `value` exceeds `MAX`.
230                 #[inline]
231                 #vis const fn from_usize(value: usize) -> Self {
232                     assert!(value <= (#max as usize));
233                     // SAFETY: We just checked that `value <= max`.
234                     unsafe {
235                         Self::from_u32_unchecked(value as u32)
236                     }
237                 }
238
239                 /// Creates a new index from a given `u32`.
240                 ///
241                 /// # Panics
242                 ///
243                 /// Will panic if `value` exceeds `MAX`.
244                 #[inline]
245                 #vis const fn from_u32(value: u32) -> Self {
246                     assert!(value <= #max);
247                     // SAFETY: We just checked that `value <= max`.
248                     unsafe {
249                         Self::from_u32_unchecked(value)
250                     }
251                 }
252
253                 /// Creates a new index from a given `u32`.
254                 ///
255                 /// # Safety
256                 ///
257                 /// The provided value must be less than or equal to the maximum value for the newtype.
258                 /// Providing a value outside this range is undefined due to layout restrictions.
259                 ///
260                 /// Prefer using `from_u32`.
261                 #[inline]
262                 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
263                     Self { private: value }
264                 }
265
266                 /// Extracts the value of this index as a `usize`.
267                 #[inline]
268                 #vis const fn index(self) -> usize {
269                     self.as_usize()
270                 }
271
272                 /// Extracts the value of this index as a `u32`.
273                 #[inline]
274                 #vis const fn as_u32(self) -> u32 {
275                     self.private
276                 }
277
278                 /// Extracts the value of this index as a `usize`.
279                 #[inline]
280                 #vis const fn as_usize(self) -> usize {
281                     self.as_u32() as usize
282                 }
283             }
284
285             impl std::ops::Add<usize> for #name {
286                 type Output = Self;
287
288                 fn add(self, other: usize) -> Self {
289                     Self::from_usize(self.index() + other)
290                 }
291             }
292
293             impl rustc_index::vec::Idx for #name {
294                 #[inline]
295                 fn new(value: usize) -> Self {
296                     Self::from_usize(value)
297                 }
298
299                 #[inline]
300                 fn index(self) -> usize {
301                     self.as_usize()
302                 }
303             }
304
305             #step
306
307             #spec_partial_eq_impl
308
309             impl From<#name> for u32 {
310                 #[inline]
311                 fn from(v: #name) -> u32 {
312                     v.as_u32()
313                 }
314             }
315
316             impl From<#name> for usize {
317                 #[inline]
318                 fn from(v: #name) -> usize {
319                     v.as_usize()
320                 }
321             }
322
323             impl From<usize> for #name {
324                 #[inline]
325                 fn from(value: usize) -> Self {
326                     Self::from_usize(value)
327                 }
328             }
329
330             impl From<u32> for #name {
331                 #[inline]
332                 fn from(value: u32) -> Self {
333                     Self::from_u32(value)
334                 }
335             }
336
337             #encodable_impls
338             #debug_impl
339         }))
340     }
341 }
342
343 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
344     let input = parse_macro_input!(input as Newtype);
345     input.0.into()
346 }