]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_macros/src/newtype.rs
Rollup merge of #107201 - compiler-errors:confusing-async-fn-note, r=estebank
[rust.git] / compiler / rustc_macros / src / newtype.rs
1 use proc_macro2::{Span, TokenStream};
2 use quote::quote;
3 use syn::parse::*;
4 use syn::*;
5
6 // We parse the input and emit the output in a single step.
7 // This field stores the final macro output
8 struct Newtype(TokenStream);
9
10 impl Parse for Newtype {
11     fn parse(input: ParseStream<'_>) -> Result<Self> {
12         let mut attrs = input.call(Attribute::parse_outer)?;
13         let vis: Visibility = input.parse()?;
14         input.parse::<Token![struct]>()?;
15         let name: Ident = input.parse()?;
16
17         let body;
18         braced!(body in input);
19
20         // Any additional `#[derive]` macro paths to apply
21         let mut derive_paths: Vec<Path> = Vec::new();
22         let mut debug_format: Option<Lit> = None;
23         let mut max = None;
24         let mut consts = Vec::new();
25         let mut encodable = true;
26         let mut ord = true;
27
28         attrs.retain(|attr| match attr.path.get_ident() {
29             Some(ident) => match &*ident.to_string() {
30                 "custom_encodable" => {
31                     encodable = false;
32                     false
33                 }
34                 "no_ord_impl" => {
35                     ord = false;
36                     false
37                 }
38                 "max" => {
39                     let Ok(Meta::NameValue(literal) )= attr.parse_meta() else {
40                         panic!("#[max = NUMBER] attribute requires max value");
41                     };
42
43                     if let Some(old) = max.replace(literal.lit) {
44                         panic!("Specified multiple max: {old:?}");
45                     }
46
47                     false
48                 }
49                 "debug_format" => {
50                     let Ok(Meta::NameValue(literal) )= attr.parse_meta() else {
51                         panic!("#[debug_format = FMT] attribute requires a format");
52                     };
53
54                     if let Some(old) = debug_format.replace(literal.lit) {
55                         panic!("Specified multiple debug format options: {old:?}");
56                     }
57
58                     false
59                 }
60                 _ => true,
61             },
62             _ => true,
63         });
64
65         loop {
66             // We've parsed everything that the user provided, so we're done
67             if body.is_empty() {
68                 break;
69             }
70
71             // Otherwise, we are parsing a user-defined constant
72             let const_attrs = body.call(Attribute::parse_outer)?;
73             body.parse::<Token![const]>()?;
74             let const_name: Ident = body.parse()?;
75             body.parse::<Token![=]>()?;
76             let const_val: Expr = body.parse()?;
77             body.parse::<Token![;]>()?;
78             consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
79         }
80
81         let debug_format =
82             debug_format.unwrap_or_else(|| Lit::Str(LitStr::new("{}", Span::call_site())));
83
84         // shave off 256 indices at the end to allow space for packing these indices into enums
85         let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
86
87         let encodable_impls = if encodable {
88             quote! {
89                 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
90                     fn decode(d: &mut D) -> Self {
91                         Self::from_u32(d.read_u32())
92                     }
93                 }
94                 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
95                     fn encode(&self, e: &mut E) {
96                         e.emit_u32(self.private);
97                     }
98                 }
99             }
100         } else {
101             quote! {}
102         };
103
104         if ord {
105             derive_paths.push(parse_quote!(Ord));
106             derive_paths.push(parse_quote!(PartialOrd));
107         }
108
109         let step = if ord {
110             quote! {
111                 impl ::std::iter::Step for #name {
112                     #[inline]
113                     fn steps_between(start: &Self, end: &Self) -> Option<usize> {
114                         <usize as ::std::iter::Step>::steps_between(
115                             &Self::index(*start),
116                             &Self::index(*end),
117                         )
118                     }
119
120                     #[inline]
121                     fn forward_checked(start: Self, u: usize) -> Option<Self> {
122                         Self::index(start).checked_add(u).map(Self::from_usize)
123                     }
124
125                     #[inline]
126                     fn backward_checked(start: Self, u: usize) -> Option<Self> {
127                         Self::index(start).checked_sub(u).map(Self::from_usize)
128                     }
129                 }
130
131                 // Safety: The implementation of `Step` upholds all invariants.
132                 unsafe impl ::std::iter::TrustedStep for #name {}
133             }
134         } else {
135             quote! {}
136         };
137
138         let debug_impl = quote! {
139             impl ::std::fmt::Debug for #name {
140                 fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
141                     write!(fmt, #debug_format, self.as_u32())
142                 }
143             }
144         };
145
146         let spec_partial_eq_impl = if let Lit::Int(max) = &max {
147             if let Ok(max_val) = max.base10_parse::<u32>() {
148                 quote! {
149                     impl core::option::SpecOptionPartialEq for #name {
150                         #[inline]
151                         fn eq(l: &Option<Self>, r: &Option<Self>) -> bool {
152                             if #max_val < u32::MAX {
153                                 l.map(|i| i.private).unwrap_or(#max_val+1) == r.map(|i| i.private).unwrap_or(#max_val+1)
154                             } else {
155                                 match (l, r) {
156                                     (Some(l), Some(r)) => r == l,
157                                     (None, None) => true,
158                                     _ => false
159                                 }
160                             }
161                         }
162                     }
163                 }
164             } else {
165                 quote! {}
166             }
167         } else {
168             quote! {}
169         };
170
171         Ok(Self(quote! {
172             #(#attrs)*
173             #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
174             #[rustc_layout_scalar_valid_range_end(#max)]
175             #[rustc_pass_by_value]
176             #vis struct #name {
177                 private: u32,
178             }
179
180             #(#consts)*
181
182             impl #name {
183                 /// Maximum value the index can take, as a `u32`.
184                 #vis const MAX_AS_U32: u32  = #max;
185
186                 /// Maximum value the index can take.
187                 #vis const MAX: Self = Self::from_u32(#max);
188
189                 /// Creates a new index from a given `usize`.
190                 ///
191                 /// # Panics
192                 ///
193                 /// Will panic if `value` exceeds `MAX`.
194                 #[inline]
195                 #vis const fn from_usize(value: usize) -> Self {
196                     assert!(value <= (#max as usize));
197                     // SAFETY: We just checked that `value <= max`.
198                     unsafe {
199                         Self::from_u32_unchecked(value as u32)
200                     }
201                 }
202
203                 /// Creates a new index from a given `u32`.
204                 ///
205                 /// # Panics
206                 ///
207                 /// Will panic if `value` exceeds `MAX`.
208                 #[inline]
209                 #vis const fn from_u32(value: u32) -> Self {
210                     assert!(value <= #max);
211                     // SAFETY: We just checked that `value <= max`.
212                     unsafe {
213                         Self::from_u32_unchecked(value)
214                     }
215                 }
216
217                 /// Creates a new index from a given `u32`.
218                 ///
219                 /// # Safety
220                 ///
221                 /// The provided value must be less than or equal to the maximum value for the newtype.
222                 /// Providing a value outside this range is undefined due to layout restrictions.
223                 ///
224                 /// Prefer using `from_u32`.
225                 #[inline]
226                 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
227                     Self { private: value }
228                 }
229
230                 /// Extracts the value of this index as a `usize`.
231                 #[inline]
232                 #vis const fn index(self) -> usize {
233                     self.as_usize()
234                 }
235
236                 /// Extracts the value of this index as a `u32`.
237                 #[inline]
238                 #vis const fn as_u32(self) -> u32 {
239                     self.private
240                 }
241
242                 /// Extracts the value of this index as a `usize`.
243                 #[inline]
244                 #vis const fn as_usize(self) -> usize {
245                     self.as_u32() as usize
246                 }
247             }
248
249             impl std::ops::Add<usize> for #name {
250                 type Output = Self;
251
252                 fn add(self, other: usize) -> Self {
253                     Self::from_usize(self.index() + other)
254                 }
255             }
256
257             impl rustc_index::vec::Idx for #name {
258                 #[inline]
259                 fn new(value: usize) -> Self {
260                     Self::from_usize(value)
261                 }
262
263                 #[inline]
264                 fn index(self) -> usize {
265                     self.as_usize()
266                 }
267             }
268
269             #step
270
271             #spec_partial_eq_impl
272
273             impl From<#name> for u32 {
274                 #[inline]
275                 fn from(v: #name) -> u32 {
276                     v.as_u32()
277                 }
278             }
279
280             impl From<#name> for usize {
281                 #[inline]
282                 fn from(v: #name) -> usize {
283                     v.as_usize()
284                 }
285             }
286
287             impl From<usize> for #name {
288                 #[inline]
289                 fn from(value: usize) -> Self {
290                     Self::from_usize(value)
291                 }
292             }
293
294             impl From<u32> for #name {
295                 #[inline]
296                 fn from(value: u32) -> Self {
297                     Self::from_u32(value)
298                 }
299             }
300
301             #encodable_impls
302             #debug_impl
303         }))
304     }
305 }
306
307 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
308     let input = parse_macro_input!(input as Newtype);
309     input.0.into()
310 }