1 use proc_macro2::{Span, TokenStream};
4 use syn::punctuated::Punctuated;
8 syn::custom_keyword!(derive);
9 syn::custom_keyword!(DEBUG_FORMAT);
10 syn::custom_keyword!(MAX);
11 syn::custom_keyword!(ENCODABLE);
12 syn::custom_keyword!(custom);
17 // The user will provide a custom `Debug` impl, so we shouldn't generate
20 // Use the specified format string in the generated `Debug` impl
21 // By default, this is "{}"
25 // Parsing the macro input and generating the output happen in a single step,
26 // so the only field stores the final macro output as a `TokenStream`.
27 struct Newtype(TokenStream);
29 impl Parse for Newtype {
30 fn parse(input: ParseStream<'_>) -> Result<Self> {
31 let attrs = input.call(Attribute::parse_outer)?;
32 let vis: Visibility = input.parse()?;
33 input.parse::<Token![struct]>()?;
34 let name: Ident = input.parse()?;
37 braced!(body in input);
39 // Any additional `#[derive]` macro paths to apply
40 let mut derive_paths: Vec<Path> = Vec::new();
41 let mut debug_format: Option<DebugFormat> = None;
43 let mut consts = Vec::new();
44 let mut encodable = true;
46 // Parse an optional trailing comma
47 let try_comma = || -> Result<()> {
48 if body.lookahead1().peek(Token![,]) {
49 body.parse::<Token![,]>()?;
54 if body.lookahead1().peek(Token![..]) {
55 body.parse::<Token![..]>()?;
58 if body.lookahead1().peek(kw::derive) {
59 body.parse::<kw::derive>()?;
61 bracketed!(derives in body);
62 let derives: Punctuated<Path, Token![,]> =
63 derives.parse_terminated(Path::parse)?;
65 derive_paths.extend(derives);
68 if body.lookahead1().peek(kw::DEBUG_FORMAT) {
69 body.parse::<kw::DEBUG_FORMAT>()?;
70 body.parse::<Token![=]>()?;
71 let new_debug_format = if body.lookahead1().peek(kw::custom) {
72 body.parse::<kw::custom>()?;
75 let format_str: LitStr = body.parse()?;
76 DebugFormat::Format(format_str.value())
79 if let Some(old) = debug_format.replace(new_debug_format) {
80 panic!("Specified multiple debug format options: {:?}", old);
84 if body.lookahead1().peek(kw::MAX) {
85 body.parse::<kw::MAX>()?;
86 body.parse::<Token![=]>()?;
87 let val: Lit = body.parse()?;
89 if let Some(old) = max.replace(val) {
90 panic!("Specified multiple MAX: {:?}", old);
94 if body.lookahead1().peek(kw::ENCODABLE) {
95 body.parse::<kw::ENCODABLE>()?;
96 body.parse::<Token![=]>()?;
97 body.parse::<kw::custom>()?;
103 // We've parsed everything that the user provided, so we're done
108 // Otherwise, we are parsing a user-defined constant
109 let const_attrs = body.call(Attribute::parse_outer)?;
110 body.parse::<Token![const]>()?;
111 let const_name: Ident = body.parse()?;
112 body.parse::<Token![=]>()?;
113 let const_val: Expr = body.parse()?;
115 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
119 let debug_format = debug_format.unwrap_or(DebugFormat::Format("{}".to_string()));
120 // shave off 256 indices at the end to allow space for packing these indices into enums
121 let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
123 let encodable_impls = if encodable {
125 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
126 fn decode(d: &mut D) -> Self {
127 Self::from_u32(d.read_u32())
130 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
131 fn encode(&self, e: &mut E) -> Result<(), E::Error> {
132 e.emit_u32(self.private)
140 let debug_impl = match debug_format {
141 DebugFormat::Custom => quote! {},
142 DebugFormat::Format(format) => {
144 impl ::std::fmt::Debug for #name {
145 fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
146 write!(fmt, #format, self.as_u32())
155 #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, #(#derive_paths),*)]
156 #[rustc_layout_scalar_valid_range_end(#max)]
164 /// Maximum value the index can take, as a `u32`.
165 #vis const MAX_AS_U32: u32 = #max;
167 /// Maximum value the index can take.
168 #vis const MAX: Self = Self::from_u32(#max);
170 /// Creates a new index from a given `usize`.
174 /// Will panic if `value` exceeds `MAX`.
176 #vis const fn from_usize(value: usize) -> Self {
177 assert!(value <= (#max as usize));
178 // SAFETY: We just checked that `value <= max`.
180 Self::from_u32_unchecked(value as u32)
184 /// Creates a new index from a given `u32`.
188 /// Will panic if `value` exceeds `MAX`.
190 #vis const fn from_u32(value: u32) -> Self {
191 assert!(value <= #max);
192 // SAFETY: We just checked that `value <= max`.
194 Self::from_u32_unchecked(value)
198 /// Creates a new index from a given `u32`.
202 /// The provided value must be less than or equal to the maximum value for the newtype.
203 /// Providing a value outside this range is undefined due to layout restrictions.
205 /// Prefer using `from_u32`.
207 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
208 Self { private: value }
211 /// Extracts the value of this index as a `usize`.
213 #vis const fn index(self) -> usize {
217 /// Extracts the value of this index as a `u32`.
219 #vis const fn as_u32(self) -> u32 {
223 /// Extracts the value of this index as a `usize`.
225 #vis const fn as_usize(self) -> usize {
226 self.as_u32() as usize
230 impl std::ops::Add<usize> for #name {
233 fn add(self, other: usize) -> Self {
234 Self::from_usize(self.index() + other)
238 impl rustc_index::vec::Idx for #name {
240 fn new(value: usize) -> Self {
241 Self::from_usize(value)
245 fn index(self) -> usize {
250 impl ::std::iter::Step for #name {
252 fn steps_between(start: &Self, end: &Self) -> Option<usize> {
253 <usize as ::std::iter::Step>::steps_between(
254 &Self::index(*start),
260 fn forward_checked(start: Self, u: usize) -> Option<Self> {
261 Self::index(start).checked_add(u).map(Self::from_usize)
265 fn backward_checked(start: Self, u: usize) -> Option<Self> {
266 Self::index(start).checked_sub(u).map(Self::from_usize)
270 // SAFETY: The implementation of `Step` upholds all invariants.
271 unsafe impl ::std::iter::TrustedStep for #name {}
273 impl From<#name> for u32 {
275 fn from(v: #name) -> u32 {
280 impl From<#name> for usize {
282 fn from(v: #name) -> usize {
287 impl From<usize> for #name {
289 fn from(value: usize) -> Self {
290 Self::from_usize(value)
294 impl From<u32> for #name {
296 fn from(value: u32) -> Self {
297 Self::from_u32(value)
307 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
308 let input = parse_macro_input!(input as Newtype);