git.lizzy.rs Git - rust.git/blob - compiler/rustc_macros/src/newtype.rs
Address review comments
[rust.git] / compiler / rustc_macros / src / newtype.rs
1 use proc_macro2::{Span, TokenStream};
2 use quote::quote;
3 use syn::parse::*;
4 use syn::punctuated::Punctuated;
5 use syn::*;
6
7 mod kw {
8     syn::custom_keyword!(derive);
9     syn::custom_keyword!(DEBUG_FORMAT);
10     syn::custom_keyword!(MAX);
11     syn::custom_keyword!(ENCODABLE);
12     syn::custom_keyword!(custom);
13 }
14
/// Selects how the `Debug` impl of a generated newtype is produced.
#[derive(Debug)]
enum DebugFormat {
    /// The user will provide a custom `Debug` impl, so none is generated.
    Custom,
    /// Generate a `Debug` impl that formats the index's `u32` value with this
    /// format string.
    ///
    /// When no `DEBUG_FORMAT` option is given, the parser defaults to `"{}"`.
    Format(String),
}
24
25 // We parse the input and emit the output in a single step.
26 // This field stores the final macro output
27 struct Newtype(TokenStream);
28
29 impl Parse for Newtype {
30     fn parse(input: ParseStream<'_>) -> Result<Self> {
31         let attrs = input.call(Attribute::parse_outer)?;
32         let vis: Visibility = input.parse()?;
33         input.parse::<Token![struct]>()?;
34         let name: Ident = input.parse()?;
35
36         let body;
37         braced!(body in input);
38
39         // Any additional `#[derive]` macro paths to apply
40         let mut derive_paths: Vec<Path> = Vec::new();
41         let mut debug_format: Option<DebugFormat> = None;
42         let mut max = None;
43         let mut consts = Vec::new();
44         let mut encodable = true;
45
46         // Parse an optional trailing comma
47         let try_comma = || -> Result<()> {
48             if body.lookahead1().peek(Token![,]) {
49                 body.parse::<Token![,]>()?;
50             }
51             Ok(())
52         };
53
54         if body.lookahead1().peek(Token![..]) {
55             body.parse::<Token![..]>()?;
56         } else {
57             loop {
58                 if body.lookahead1().peek(kw::derive) {
59                     body.parse::<kw::derive>()?;
60                     let derives;
61                     bracketed!(derives in body);
62                     let derives: Punctuated<Path, Token![,]> =
63                         derives.parse_terminated(Path::parse)?;
64                     try_comma()?;
65                     derive_paths.extend(derives);
66                     continue;
67                 }
68                 if body.lookahead1().peek(kw::DEBUG_FORMAT) {
69                     body.parse::<kw::DEBUG_FORMAT>()?;
70                     body.parse::<Token![=]>()?;
71                     let new_debug_format = if body.lookahead1().peek(kw::custom) {
72                         body.parse::<kw::custom>()?;
73                         DebugFormat::Custom
74                     } else {
75                         let format_str: LitStr = body.parse()?;
76                         DebugFormat::Format(format_str.value())
77                     };
78                     try_comma()?;
79                     if let Some(old) = debug_format.replace(new_debug_format) {
80                         panic!("Specified multiple debug format options: {:?}", old);
81                     }
82                     continue;
83                 }
84                 if body.lookahead1().peek(kw::MAX) {
85                     body.parse::<kw::MAX>()?;
86                     body.parse::<Token![=]>()?;
87                     let val: Lit = body.parse()?;
88                     try_comma()?;
89                     if let Some(old) = max.replace(val) {
90                         panic!("Specified multiple MAX: {:?}", old);
91                     }
92                     continue;
93                 }
94                 if body.lookahead1().peek(kw::ENCODABLE) {
95                     body.parse::<kw::ENCODABLE>()?;
96                     body.parse::<Token![=]>()?;
97                     body.parse::<kw::custom>()?;
98                     try_comma()?;
99                     encodable = false;
100                     continue;
101                 }
102
103                 // We've parsed everything that the user provided, so we're done
104                 if body.is_empty() {
105                     break;
106                 }
107
108                 // Otherwise, we are parsng a user-defined constant
109                 let const_attrs = body.call(Attribute::parse_outer)?;
110                 body.parse::<Token![const]>()?;
111                 let const_name: Ident = body.parse()?;
112                 body.parse::<Token![=]>()?;
113                 let const_val: Expr = body.parse()?;
114                 try_comma()?;
115                 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
116             }
117         }
118
119         let debug_format = debug_format.unwrap_or(DebugFormat::Format("{}".to_string()));
120         // shave off 256 indices at the end to allow space for packing these indices into enums
121         let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
122
123         let encodable_impls = if encodable {
124             quote! {
125                 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
126                     fn decode(d: &mut D) -> Self {
127                         Self::from_u32(d.read_u32())
128                     }
129                 }
130                 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
131                     fn encode(&self, e: &mut E) -> Result<(), E::Error> {
132                         e.emit_u32(self.private)
133                     }
134                 }
135             }
136         } else {
137             quote! {}
138         };
139
140         let debug_impl = match debug_format {
141             DebugFormat::Custom => quote! {},
142             DebugFormat::Format(format) => {
143                 quote! {
144                     impl ::std::fmt::Debug for #name {
145                         fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
146                             write!(fmt, #format, self.as_u32())
147                         }
148                     }
149                 }
150             }
151         };
152
153         Ok(Self(quote! {
154             #(#attrs)*
155             #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, #(#derive_paths),*)]
156             #[rustc_layout_scalar_valid_range_end(#max)]
157             #vis struct #name {
158                 private: u32,
159             }
160
161             #(#consts)*
162
163             impl #name {
164                 /// Maximum value the index can take, as a `u32`.
165                 #vis const MAX_AS_U32: u32  = #max;
166
167                 /// Maximum value the index can take.
168                 #vis const MAX: Self = Self::from_u32(#max);
169
170                 /// Creates a new index from a given `usize`.
171                 ///
172                 /// # Panics
173                 ///
174                 /// Will panic if `value` exceeds `MAX`.
175                 #[inline]
176                 #vis const fn from_usize(value: usize) -> Self {
177                     assert!(value <= (#max as usize));
178                     // SAFETY: We just checked that `value <= max`.
179                     unsafe {
180                         Self::from_u32_unchecked(value as u32)
181                     }
182                 }
183
184                 /// Creates a new index from a given `u32`.
185                 ///
186                 /// # Panics
187                 ///
188                 /// Will panic if `value` exceeds `MAX`.
189                 #[inline]
190                 #vis const fn from_u32(value: u32) -> Self {
191                     assert!(value <= #max);
192                     // SAFETY: We just checked that `value <= max`.
193                     unsafe {
194                         Self::from_u32_unchecked(value)
195                     }
196                 }
197
198                 /// Creates a new index from a given `u32`.
199                 ///
200                 /// # Safety
201                 ///
202                 /// The provided value must be less than or equal to the maximum value for the newtype.
203                 /// Providing a value outside this range is undefined due to layout restrictions.
204                 ///
205                 /// Prefer using `from_u32`.
206                 #[inline]
207                 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
208                     Self { private: value }
209                 }
210
211                 /// Extracts the value of this index as a `usize`.
212                 #[inline]
213                 #vis const fn index(self) -> usize {
214                     self.as_usize()
215                 }
216
217                 /// Extracts the value of this index as a `u32`.
218                 #[inline]
219                 #vis const fn as_u32(self) -> u32 {
220                     self.private
221                 }
222
223                 /// Extracts the value of this index as a `usize`.
224                 #[inline]
225                 #vis const fn as_usize(self) -> usize {
226                     self.as_u32() as usize
227                 }
228             }
229
230             impl std::ops::Add<usize> for #name {
231                 type Output = Self;
232
233                 fn add(self, other: usize) -> Self {
234                     Self::from_usize(self.index() + other)
235                 }
236             }
237
238             impl rustc_index::vec::Idx for #name {
239                 #[inline]
240                 fn new(value: usize) -> Self {
241                     Self::from_usize(value)
242                 }
243
244                 #[inline]
245                 fn index(self) -> usize {
246                     self.as_usize()
247                 }
248             }
249
250             impl ::std::iter::Step for #name {
251                 #[inline]
252                 fn steps_between(start: &Self, end: &Self) -> Option<usize> {
253                     <usize as ::std::iter::Step>::steps_between(
254                         &Self::index(*start),
255                         &Self::index(*end),
256                     )
257                 }
258
259                 #[inline]
260                 fn forward_checked(start: Self, u: usize) -> Option<Self> {
261                     Self::index(start).checked_add(u).map(Self::from_usize)
262                 }
263
264                 #[inline]
265                 fn backward_checked(start: Self, u: usize) -> Option<Self> {
266                     Self::index(start).checked_sub(u).map(Self::from_usize)
267                 }
268             }
269
270             // Safety: The implementation of `Step` upholds all invariants.
271             unsafe impl ::std::iter::TrustedStep for #name {}
272
273             impl From<#name> for u32 {
274                 #[inline]
275                 fn from(v: #name) -> u32 {
276                     v.as_u32()
277                 }
278             }
279
280             impl From<#name> for usize {
281                 #[inline]
282                 fn from(v: #name) -> usize {
283                     v.as_usize()
284                 }
285             }
286
287             impl From<usize> for #name {
288                 #[inline]
289                 fn from(value: usize) -> Self {
290                     Self::from_usize(value)
291                 }
292             }
293
294             impl From<u32> for #name {
295                 #[inline]
296                 fn from(value: u32) -> Self {
297                     Self::from_u32(value)
298                 }
299             }
300
301             #encodable_impls
302             #debug_impl
303         }))
304     }
305 }
306
307 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
308     let input = parse_macro_input!(input as Newtype);
309     input.0.into()
310 }