]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_macros/src/newtype.rs
Use `#[derive]` instead of custom syntax in all `newtype_index`
[rust.git] / compiler / rustc_macros / src / newtype.rs
1 use proc_macro2::{Span, TokenStream};
2 use quote::quote;
3 use syn::parse::*;
4 use syn::*;
5
6 mod kw {
7     syn::custom_keyword!(DEBUG_FORMAT);
8     syn::custom_keyword!(MAX);
9     syn::custom_keyword!(ENCODABLE);
10     syn::custom_keyword!(custom);
11     syn::custom_keyword!(ORD_IMPL);
12 }
13
/// How the generated index type should implement `Debug`.
#[derive(Debug)]
enum DebugFormat {
    /// Selected by `DEBUG_FORMAT = custom`: no `Debug` impl is generated,
    /// because the user is expected to write one by hand.
    Custom,
    /// Selected by `DEBUG_FORMAT = "..."`: the generated `Debug` impl writes
    /// the index's `u32` value using this format string. When `DEBUG_FORMAT`
    /// is not given at all, the format string defaults to `"{}"`.
    Format(String),
}
23
24 // We parse the input and emit the output in a single step.
25 // This field stores the final macro output
26 struct Newtype(TokenStream);
27
28 impl Parse for Newtype {
29     fn parse(input: ParseStream<'_>) -> Result<Self> {
30         let attrs = input.call(Attribute::parse_outer)?;
31         let vis: Visibility = input.parse()?;
32         input.parse::<Token![struct]>()?;
33         let name: Ident = input.parse()?;
34
35         let body;
36         braced!(body in input);
37
38         // Any additional `#[derive]` macro paths to apply
39         let mut derive_paths: Vec<Path> = Vec::new();
40         let mut debug_format: Option<DebugFormat> = None;
41         let mut max = None;
42         let mut consts = Vec::new();
43         let mut encodable = true;
44         let mut ord = true;
45
46         // Parse an optional trailing comma
47         let try_comma = || -> Result<()> {
48             if body.lookahead1().peek(Token![,]) {
49                 body.parse::<Token![,]>()?;
50             }
51             Ok(())
52         };
53
54         if body.lookahead1().peek(Token![..]) {
55             body.parse::<Token![..]>()?;
56         } else {
57             loop {
58                 if body.lookahead1().peek(kw::DEBUG_FORMAT) {
59                     body.parse::<kw::DEBUG_FORMAT>()?;
60                     body.parse::<Token![=]>()?;
61                     let new_debug_format = if body.lookahead1().peek(kw::custom) {
62                         body.parse::<kw::custom>()?;
63                         DebugFormat::Custom
64                     } else {
65                         let format_str: LitStr = body.parse()?;
66                         DebugFormat::Format(format_str.value())
67                     };
68                     try_comma()?;
69                     if let Some(old) = debug_format.replace(new_debug_format) {
70                         panic!("Specified multiple debug format options: {:?}", old);
71                     }
72                     continue;
73                 }
74                 if body.lookahead1().peek(kw::MAX) {
75                     body.parse::<kw::MAX>()?;
76                     body.parse::<Token![=]>()?;
77                     let val: Lit = body.parse()?;
78                     try_comma()?;
79                     if let Some(old) = max.replace(val) {
80                         panic!("Specified multiple MAX: {:?}", old);
81                     }
82                     continue;
83                 }
84                 if body.lookahead1().peek(kw::ENCODABLE) {
85                     body.parse::<kw::ENCODABLE>()?;
86                     body.parse::<Token![=]>()?;
87                     body.parse::<kw::custom>()?;
88                     try_comma()?;
89                     encodable = false;
90                     continue;
91                 }
92                 if body.lookahead1().peek(kw::ORD_IMPL) {
93                     body.parse::<kw::ORD_IMPL>()?;
94                     body.parse::<Token![=]>()?;
95                     body.parse::<kw::custom>()?;
96                     ord = false;
97                     continue;
98                 }
99
100                 // We've parsed everything that the user provided, so we're done
101                 if body.is_empty() {
102                     break;
103                 }
104
105                 // Otherwise, we are parsing a user-defined constant
106                 let const_attrs = body.call(Attribute::parse_outer)?;
107                 body.parse::<Token![const]>()?;
108                 let const_name: Ident = body.parse()?;
109                 body.parse::<Token![=]>()?;
110                 let const_val: Expr = body.parse()?;
111                 try_comma()?;
112                 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
113             }
114         }
115
116         let debug_format = debug_format.unwrap_or(DebugFormat::Format("{}".to_string()));
117         // shave off 256 indices at the end to allow space for packing these indices into enums
118         let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
119
120         let encodable_impls = if encodable {
121             quote! {
122                 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
123                     fn decode(d: &mut D) -> Self {
124                         Self::from_u32(d.read_u32())
125                     }
126                 }
127                 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
128                     fn encode(&self, e: &mut E) {
129                         e.emit_u32(self.private);
130                     }
131                 }
132             }
133         } else {
134             quote! {}
135         };
136
137         if ord {
138             derive_paths.push(parse_quote!(Ord));
139             derive_paths.push(parse_quote!(PartialOrd));
140         }
141
142         let step = if ord {
143             quote! {
144                 impl ::std::iter::Step for #name {
145                     #[inline]
146                     fn steps_between(start: &Self, end: &Self) -> Option<usize> {
147                         <usize as ::std::iter::Step>::steps_between(
148                             &Self::index(*start),
149                             &Self::index(*end),
150                         )
151                     }
152
153                     #[inline]
154                     fn forward_checked(start: Self, u: usize) -> Option<Self> {
155                         Self::index(start).checked_add(u).map(Self::from_usize)
156                     }
157
158                     #[inline]
159                     fn backward_checked(start: Self, u: usize) -> Option<Self> {
160                         Self::index(start).checked_sub(u).map(Self::from_usize)
161                     }
162                 }
163
164                 // Safety: The implementation of `Step` upholds all invariants.
165                 unsafe impl ::std::iter::TrustedStep for #name {}
166             }
167         } else {
168             quote! {}
169         };
170
171         let debug_impl = match debug_format {
172             DebugFormat::Custom => quote! {},
173             DebugFormat::Format(format) => {
174                 quote! {
175                     impl ::std::fmt::Debug for #name {
176                         fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
177                             write!(fmt, #format, self.as_u32())
178                         }
179                     }
180                 }
181             }
182         };
183         let spec_partial_eq_impl = if let Lit::Int(max) = &max {
184             if let Ok(max_val) = max.base10_parse::<u32>() {
185                 quote! {
186                     impl core::option::SpecOptionPartialEq for #name {
187                         #[inline]
188                         fn eq(l: &Option<Self>, r: &Option<Self>) -> bool {
189                             if #max_val < u32::MAX {
190                                 l.map(|i| i.private).unwrap_or(#max_val+1) == r.map(|i| i.private).unwrap_or(#max_val+1)
191                             } else {
192                                 match (l, r) {
193                                     (Some(l), Some(r)) => r == l,
194                                     (None, None) => true,
195                                     _ => false
196                                 }
197                             }
198                         }
199                     }
200                 }
201             } else {
202                 quote! {}
203             }
204         } else {
205             quote! {}
206         };
207
208         Ok(Self(quote! {
209             #(#attrs)*
210             #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
211             #[rustc_layout_scalar_valid_range_end(#max)]
212             #[rustc_pass_by_value]
213             #vis struct #name {
214                 private: u32,
215             }
216
217             #(#consts)*
218
219             impl #name {
220                 /// Maximum value the index can take, as a `u32`.
221                 #vis const MAX_AS_U32: u32  = #max;
222
223                 /// Maximum value the index can take.
224                 #vis const MAX: Self = Self::from_u32(#max);
225
226                 /// Creates a new index from a given `usize`.
227                 ///
228                 /// # Panics
229                 ///
230                 /// Will panic if `value` exceeds `MAX`.
231                 #[inline]
232                 #vis const fn from_usize(value: usize) -> Self {
233                     assert!(value <= (#max as usize));
234                     // SAFETY: We just checked that `value <= max`.
235                     unsafe {
236                         Self::from_u32_unchecked(value as u32)
237                     }
238                 }
239
240                 /// Creates a new index from a given `u32`.
241                 ///
242                 /// # Panics
243                 ///
244                 /// Will panic if `value` exceeds `MAX`.
245                 #[inline]
246                 #vis const fn from_u32(value: u32) -> Self {
247                     assert!(value <= #max);
248                     // SAFETY: We just checked that `value <= max`.
249                     unsafe {
250                         Self::from_u32_unchecked(value)
251                     }
252                 }
253
254                 /// Creates a new index from a given `u32`.
255                 ///
256                 /// # Safety
257                 ///
258                 /// The provided value must be less than or equal to the maximum value for the newtype.
259                 /// Providing a value outside this range is undefined due to layout restrictions.
260                 ///
261                 /// Prefer using `from_u32`.
262                 #[inline]
263                 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
264                     Self { private: value }
265                 }
266
267                 /// Extracts the value of this index as a `usize`.
268                 #[inline]
269                 #vis const fn index(self) -> usize {
270                     self.as_usize()
271                 }
272
273                 /// Extracts the value of this index as a `u32`.
274                 #[inline]
275                 #vis const fn as_u32(self) -> u32 {
276                     self.private
277                 }
278
279                 /// Extracts the value of this index as a `usize`.
280                 #[inline]
281                 #vis const fn as_usize(self) -> usize {
282                     self.as_u32() as usize
283                 }
284             }
285
286             impl std::ops::Add<usize> for #name {
287                 type Output = Self;
288
289                 fn add(self, other: usize) -> Self {
290                     Self::from_usize(self.index() + other)
291                 }
292             }
293
294             impl rustc_index::vec::Idx for #name {
295                 #[inline]
296                 fn new(value: usize) -> Self {
297                     Self::from_usize(value)
298                 }
299
300                 #[inline]
301                 fn index(self) -> usize {
302                     self.as_usize()
303                 }
304             }
305
306             #step
307
308             #spec_partial_eq_impl
309
310             impl From<#name> for u32 {
311                 #[inline]
312                 fn from(v: #name) -> u32 {
313                     v.as_u32()
314                 }
315             }
316
317             impl From<#name> for usize {
318                 #[inline]
319                 fn from(v: #name) -> usize {
320                     v.as_usize()
321                 }
322             }
323
324             impl From<usize> for #name {
325                 #[inline]
326                 fn from(value: usize) -> Self {
327                     Self::from_usize(value)
328                 }
329             }
330
331             impl From<u32> for #name {
332                 #[inline]
333                 fn from(value: u32) -> Self {
334                     Self::from_u32(value)
335                 }
336             }
337
338             #encodable_impls
339             #debug_impl
340         }))
341     }
342 }
343
344 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
345     let input = parse_macro_input!(input as Newtype);
346     input.0.into()
347 }