1 use proc_macro2::{Span, TokenStream};
4 use syn::punctuated::Punctuated;
8 syn::custom_keyword!(derive); // `derive [Path, ...]`: extra `#[derive]` paths for the generated type
9 syn::custom_keyword!(DEBUG_FORMAT); // `DEBUG_FORMAT = custom` (user writes `Debug`) or `DEBUG_FORMAT = "fmt"`
10 syn::custom_keyword!(MAX); // `MAX = <literal>`: largest valid index value (defaults to 0xFFFF_FF00)
11 syn::custom_keyword!(ENCODABLE); // `ENCODABLE = custom`: presumably opts out of the generated rustc_serialize impls — confirm
12 syn::custom_keyword!(custom); // value keyword accepted after `DEBUG_FORMAT =`, `ENCODABLE =` and `ORD_IMPL =`
13 syn::custom_keyword!(ORD_IMPL); // `ORD_IMPL = custom`: presumably suppresses the derived `Ord`/`PartialOrd` — confirm
18 // The user will provide a custom `Debug` impl, so we shouldn't generate
21 // Use the specified format string in the generated `Debug` impl
22 // By default, this is "{}"
26 /// Parse target of the `newtype` proc macro. Input is parsed and output is emitted in one step,
27 /// so the wrapped `TokenStream` is already the final macro expansion, not an intermediate AST.
28 struct Newtype(TokenStream);
30 impl Parse for Newtype {
31 fn parse(input: ParseStream<'_>) -> Result<Self> {
32 let attrs = input.call(Attribute::parse_outer)?;
33 let vis: Visibility = input.parse()?;
34 input.parse::<Token![struct]>()?;
35 let name: Ident = input.parse()?;
38 braced!(body in input);
40 // Any additional `#[derive]` macro paths to apply
41 let mut derive_paths: Vec<Path> = Vec::new();
42 let mut debug_format: Option<DebugFormat> = None;
44 let mut consts = Vec::new();
45 let mut encodable = true;
48 // Parse an optional trailing comma
49 let try_comma = || -> Result<()> {
50 if body.lookahead1().peek(Token![,]) {
51 body.parse::<Token![,]>()?;
56 if body.lookahead1().peek(Token![..]) {
57 body.parse::<Token![..]>()?;
60 if body.lookahead1().peek(kw::derive) {
61 body.parse::<kw::derive>()?;
63 bracketed!(derives in body);
64 let derives: Punctuated<Path, Token![,]> =
65 derives.parse_terminated(Path::parse)?;
67 derive_paths.extend(derives);
70 if body.lookahead1().peek(kw::DEBUG_FORMAT) {
71 body.parse::<kw::DEBUG_FORMAT>()?;
72 body.parse::<Token![=]>()?;
73 let new_debug_format = if body.lookahead1().peek(kw::custom) {
74 body.parse::<kw::custom>()?;
77 let format_str: LitStr = body.parse()?;
78 DebugFormat::Format(format_str.value())
81 if let Some(old) = debug_format.replace(new_debug_format) {
82 panic!("Specified multiple debug format options: {:?}", old);
86 if body.lookahead1().peek(kw::MAX) {
87 body.parse::<kw::MAX>()?;
88 body.parse::<Token![=]>()?;
89 let val: Lit = body.parse()?;
91 if let Some(old) = max.replace(val) {
92 panic!("Specified multiple MAX: {:?}", old);
96 if body.lookahead1().peek(kw::ENCODABLE) {
97 body.parse::<kw::ENCODABLE>()?;
98 body.parse::<Token![=]>()?;
99 body.parse::<kw::custom>()?;
104 if body.lookahead1().peek(kw::ORD_IMPL) {
105 body.parse::<kw::ORD_IMPL>()?;
106 body.parse::<Token![=]>()?;
107 body.parse::<kw::custom>()?;
112 // We've parsed everything that the user provided, so we're done
117 // Otherwise, we are parsing a user-defined constant
118 let const_attrs = body.call(Attribute::parse_outer)?;
119 body.parse::<Token![const]>()?;
120 let const_name: Ident = body.parse()?;
121 body.parse::<Token![=]>()?;
122 let const_val: Expr = body.parse()?;
124 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
128 let debug_format = debug_format.unwrap_or(DebugFormat::Format("{}".to_string()));
129 // shave off 256 indices at the end to allow space for packing these indices into enums
130 let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
132 let encodable_impls = if encodable {
134 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
135 fn decode(d: &mut D) -> Self {
136 Self::from_u32(d.read_u32())
139 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
140 fn encode(&self, e: &mut E) {
141 e.emit_u32(self.private);
150 derive_paths.push(parse_quote!(Ord));
151 derive_paths.push(parse_quote!(PartialOrd));
156 impl ::std::iter::Step for #name {
158 fn steps_between(start: &Self, end: &Self) -> Option<usize> {
159 <usize as ::std::iter::Step>::steps_between(
160 &Self::index(*start),
166 fn forward_checked(start: Self, u: usize) -> Option<Self> {
167 Self::index(start).checked_add(u).map(Self::from_usize)
171 fn backward_checked(start: Self, u: usize) -> Option<Self> {
172 Self::index(start).checked_sub(u).map(Self::from_usize)
176 // Safety: The implementation of `Step` upholds all invariants.
177 unsafe impl ::std::iter::TrustedStep for #name {}
183 let debug_impl = match debug_format {
184 DebugFormat::Custom => quote! {},
185 DebugFormat::Format(format) => {
187 impl ::std::fmt::Debug for #name {
188 fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
189 write!(fmt, #format, self.as_u32())
195 let spec_partial_eq_impl = if let Lit::Int(max) = &max {
196 if let Ok(max_val) = max.base10_parse::<u32>() {
198 impl core::option::SpecOptionPartialEq for #name {
200 fn eq(l: &Option<Self>, r: &Option<Self>) -> bool {
201 if #max_val < u32::MAX {
202 l.map(|i| i.private).unwrap_or(#max_val+1) == r.map(|i| i.private).unwrap_or(#max_val+1)
205 (Some(l), Some(r)) => r == l,
206 (None, None) => true,
222 #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
223 #[rustc_layout_scalar_valid_range_end(#max)]
224 #[rustc_pass_by_value]
232 /// Maximum value the index can take, as a `u32`.
233 #vis const MAX_AS_U32: u32 = #max;
235 /// Maximum value the index can take.
236 #vis const MAX: Self = Self::from_u32(#max);
238 /// Creates a new index from a given `usize`.
242 /// Will panic if `value` exceeds `MAX`.
244 #vis const fn from_usize(value: usize) -> Self {
245 assert!(value <= (#max as usize));
246 // SAFETY: We just checked that `value <= max`.
248 Self::from_u32_unchecked(value as u32)
252 /// Creates a new index from a given `u32`.
256 /// Will panic if `value` exceeds `MAX`.
258 #vis const fn from_u32(value: u32) -> Self {
259 assert!(value <= #max);
260 // SAFETY: We just checked that `value <= max`.
262 Self::from_u32_unchecked(value)
266 /// Creates a new index from a given `u32`.
270 /// The provided value must be less than or equal to the maximum value for the newtype.
271 /// Providing a value outside this range is undefined due to layout restrictions.
273 /// Prefer using `from_u32`.
275 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
276 Self { private: value }
279 /// Extracts the value of this index as a `usize`.
281 #vis const fn index(self) -> usize {
285 /// Extracts the value of this index as a `u32`.
287 #vis const fn as_u32(self) -> u32 {
291 /// Extracts the value of this index as a `usize`.
293 #vis const fn as_usize(self) -> usize {
294 self.as_u32() as usize
298 impl std::ops::Add<usize> for #name {
301 fn add(self, other: usize) -> Self {
302 Self::from_usize(self.index() + other)
306 impl rustc_index::vec::Idx for #name {
308 fn new(value: usize) -> Self {
309 Self::from_usize(value)
313 fn index(self) -> usize {
320 #spec_partial_eq_impl
322 impl From<#name> for u32 {
324 fn from(v: #name) -> u32 {
329 impl From<#name> for usize {
331 fn from(v: #name) -> usize {
336 impl From<usize> for #name {
338 fn from(value: usize) -> Self {
339 Self::from_usize(value)
343 impl From<u32> for #name {
345 fn from(value: u32) -> Self {
346 Self::from_u32(value)
356 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
357 let input = parse_macro_input!(input as Newtype);