1 use proc_macro2::{Span, TokenStream};
6 // We parse the input and emit the output in a single step.
7 // This field stores the final macro output
8 struct Newtype(TokenStream);
10 impl Parse for Newtype {
11 fn parse(input: ParseStream<'_>) -> Result<Self> {
12 let mut attrs = input.call(Attribute::parse_outer)?;
13 let vis: Visibility = input.parse()?;
14 input.parse::<Token![struct]>()?;
15 let name: Ident = input.parse()?;
18 braced!(body in input);
20 // Any additional `#[derive]` macro paths to apply
21 let mut derive_paths: Vec<Path> = Vec::new();
22 let mut debug_format: Option<Lit> = None;
24 let mut consts = Vec::new();
25 let mut encodable = true;
28 attrs.retain(|attr| match attr.path.get_ident() {
29 Some(ident) => match &*ident.to_string() {
30 "custom_encodable" => {
39 let Ok(Meta::NameValue(literal) )= attr.parse_meta() else {
40 panic!("#[max = NUMBER] attribute requires max value");
43 if let Some(old) = max.replace(literal.lit) {
44 panic!("Specified multiple max: {old:?}");
50 let Ok(Meta::NameValue(literal) )= attr.parse_meta() else {
51 panic!("#[debug_format = FMT] attribute requires a format");
54 if let Some(old) = debug_format.replace(literal.lit) {
55 panic!("Specified multiple debug format options: {old:?}");
66 // We've parsed everything that the user provided, so we're done
71 // Otherwise, we are parsing a user-defined constant
72 let const_attrs = body.call(Attribute::parse_outer)?;
73 body.parse::<Token![const]>()?;
74 let const_name: Ident = body.parse()?;
75 body.parse::<Token![=]>()?;
76 let const_val: Expr = body.parse()?;
77 body.parse::<Token![;]>()?;
78 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
82 debug_format.unwrap_or_else(|| Lit::Str(LitStr::new("{}", Span::call_site())));
84 // shave off 256 indices at the end to allow space for packing these indices into enums
85 let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
87 let encodable_impls = if encodable {
89 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
90 fn decode(d: &mut D) -> Self {
91 Self::from_u32(d.read_u32())
94 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
95 fn encode(&self, e: &mut E) {
96 e.emit_u32(self.private);
105 derive_paths.push(parse_quote!(Ord));
106 derive_paths.push(parse_quote!(PartialOrd));
111 impl ::std::iter::Step for #name {
113 fn steps_between(start: &Self, end: &Self) -> Option<usize> {
114 <usize as ::std::iter::Step>::steps_between(
115 &Self::index(*start),
121 fn forward_checked(start: Self, u: usize) -> Option<Self> {
122 Self::index(start).checked_add(u).map(Self::from_usize)
126 fn backward_checked(start: Self, u: usize) -> Option<Self> {
127 Self::index(start).checked_sub(u).map(Self::from_usize)
131 // Safety: The implementation of `Step` upholds all invariants.
132 unsafe impl ::std::iter::TrustedStep for #name {}
138 let debug_impl = quote! {
139 impl ::std::fmt::Debug for #name {
140 fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
141 write!(fmt, #debug_format, self.as_u32())
146 let spec_partial_eq_impl = if let Lit::Int(max) = &max {
147 if let Ok(max_val) = max.base10_parse::<u32>() {
149 impl core::option::SpecOptionPartialEq for #name {
151 fn eq(l: &Option<Self>, r: &Option<Self>) -> bool {
152 if #max_val < u32::MAX {
153 l.map(|i| i.private).unwrap_or(#max_val+1) == r.map(|i| i.private).unwrap_or(#max_val+1)
156 (Some(l), Some(r)) => r == l,
157 (None, None) => true,
173 #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
174 #[rustc_layout_scalar_valid_range_end(#max)]
175 #[rustc_pass_by_value]
183 /// Maximum value the index can take, as a `u32`.
184 #vis const MAX_AS_U32: u32 = #max;
186 /// Maximum value the index can take.
187 #vis const MAX: Self = Self::from_u32(#max);
189 /// Creates a new index from a given `usize`.
193 /// Will panic if `value` exceeds `MAX`.
195 #vis const fn from_usize(value: usize) -> Self {
196 assert!(value <= (#max as usize));
197 // SAFETY: We just checked that `value <= max`.
199 Self::from_u32_unchecked(value as u32)
203 /// Creates a new index from a given `u32`.
207 /// Will panic if `value` exceeds `MAX`.
209 #vis const fn from_u32(value: u32) -> Self {
210 assert!(value <= #max);
211 // SAFETY: We just checked that `value <= max`.
213 Self::from_u32_unchecked(value)
217 /// Creates a new index from a given `u32`.
221 /// The provided value must be less than or equal to the maximum value for the newtype.
222 /// Providing a value outside this range is undefined due to layout restrictions.
224 /// Prefer using `from_u32`.
226 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
227 Self { private: value }
230 /// Extracts the value of this index as a `usize`.
232 #vis const fn index(self) -> usize {
236 /// Extracts the value of this index as a `u32`.
238 #vis const fn as_u32(self) -> u32 {
242 /// Extracts the value of this index as a `usize`.
244 #vis const fn as_usize(self) -> usize {
245 self.as_u32() as usize
249 impl std::ops::Add<usize> for #name {
252 fn add(self, other: usize) -> Self {
253 Self::from_usize(self.index() + other)
257 impl rustc_index::vec::Idx for #name {
259 fn new(value: usize) -> Self {
260 Self::from_usize(value)
264 fn index(self) -> usize {
271 #spec_partial_eq_impl
273 impl From<#name> for u32 {
275 fn from(v: #name) -> u32 {
280 impl From<#name> for usize {
282 fn from(v: #name) -> usize {
287 impl From<usize> for #name {
289 fn from(value: usize) -> Self {
290 Self::from_usize(value)
294 impl From<u32> for #name {
296 fn from(value: u32) -> Self {
297 Self::from_u32(value)
307 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
308 let input = parse_macro_input!(input as Newtype);