1 use proc_macro2::{Span, TokenStream};
4 use syn::punctuated::Punctuated;
// Custom keywords for the options accepted inside the braced body of the macro input.
// Each one is detected with `body.lookahead1().peek(kw::...)` in `Newtype::parse`:
//   `derive [Path, ...]`          extra paths appended to the generated `#[derive(...)]`
//   `DEBUG_FORMAT = custom`       the user writes their own `Debug` impl; none is generated
//   `DEBUG_FORMAT = "<format>"`   format string used by the generated `Debug` impl
//                                 (defaults to "{}")
//   `MAX = <literal>`             largest valid value of the index
//                                 (defaults to 0xFFFF_FF00)
//   `ENCODABLE = custom`          NOTE(review): presumably disables the generated
//                                 `Decodable`/`Encodable` impls; the assignment is not visible here
//   `ORD_IMPL = custom`           NOTE(review): presumably disables the generated ordering
//                                 support (`Ord`/`PartialOrd` derives); not visible here
// `custom` is the value keyword that follows `DEBUG_FORMAT =`, `ENCODABLE =` and `ORD_IMPL =`.
// Specifying `DEBUG_FORMAT` or `MAX` more than once panics during parsing.
8 syn::custom_keyword!(derive);
9 syn::custom_keyword!(DEBUG_FORMAT);
10 syn::custom_keyword!(MAX);
11 syn::custom_keyword!(ENCODABLE);
12 syn::custom_keyword!(custom);
13 syn::custom_keyword!(ORD_IMPL);
18 // The user will provide a custom `Debug` impl, so we shouldn't generate
21 // Use the specified format string in the generated `Debug` impl
22 // By default, this is "{}"
26 // We parse the input and emit the output in a single step.
27 // This field stores the final macro output.
// `Newtype::parse` does all the work. It reads the attributes, visibility and name of the
// user's `struct`, then the option keywords and `const` items in its braced body. Finally it
// builds the generated items with `quote!`. The `newtype` entry point obtains a value via
// `parse_macro_input!(input as Newtype)`.
28 struct Newtype(TokenStream);
30 impl Parse for Newtype {
31 fn parse(input: ParseStream<'_>) -> Result<Self> {
32 let attrs = input.call(Attribute::parse_outer)?;
33 let vis: Visibility = input.parse()?;
34 input.parse::<Token![struct]>()?;
35 let name: Ident = input.parse()?;
38 braced!(body in input);
40 // Any additional `#[derive]` macro paths to apply
41 let mut derive_paths: Vec<Path> = Vec::new();
42 let mut debug_format: Option<DebugFormat> = None;
44 let mut consts = Vec::new();
45 let mut encodable = true;
48 // Parse an optional trailing comma
49 let try_comma = || -> Result<()> {
50 if body.lookahead1().peek(Token![,]) {
51 body.parse::<Token![,]>()?;
56 if body.lookahead1().peek(Token![..]) {
57 body.parse::<Token![..]>()?;
60 if body.lookahead1().peek(kw::derive) {
61 body.parse::<kw::derive>()?;
63 bracketed!(derives in body);
64 let derives: Punctuated<Path, Token![,]> =
65 derives.parse_terminated(Path::parse)?;
67 derive_paths.extend(derives);
70 if body.lookahead1().peek(kw::DEBUG_FORMAT) {
71 body.parse::<kw::DEBUG_FORMAT>()?;
72 body.parse::<Token![=]>()?;
73 let new_debug_format = if body.lookahead1().peek(kw::custom) {
74 body.parse::<kw::custom>()?;
77 let format_str: LitStr = body.parse()?;
78 DebugFormat::Format(format_str.value())
81 if let Some(old) = debug_format.replace(new_debug_format) {
82 panic!("Specified multiple debug format options: {:?}", old);
86 if body.lookahead1().peek(kw::MAX) {
87 body.parse::<kw::MAX>()?;
88 body.parse::<Token![=]>()?;
89 let val: Lit = body.parse()?;
91 if let Some(old) = max.replace(val) {
92 panic!("Specified multiple MAX: {:?}", old);
96 if body.lookahead1().peek(kw::ENCODABLE) {
97 body.parse::<kw::ENCODABLE>()?;
98 body.parse::<Token![=]>()?;
99 body.parse::<kw::custom>()?;
104 if body.lookahead1().peek(kw::ORD_IMPL) {
105 body.parse::<kw::ORD_IMPL>()?;
106 body.parse::<Token![=]>()?;
107 body.parse::<kw::custom>()?;
112 // We've parsed everything that the user provided, so we're done
117 // Otherwise, we are parsing a user-defined constant
118 let const_attrs = body.call(Attribute::parse_outer)?;
119 body.parse::<Token![const]>()?;
120 let const_name: Ident = body.parse()?;
121 body.parse::<Token![=]>()?;
122 let const_val: Expr = body.parse()?;
124 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
128 let debug_format = debug_format.unwrap_or(DebugFormat::Format("{}".to_string()));
129 // shave off 256 indices at the end to allow space for packing these indices into enums
130 let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
132 let encodable_impls = if encodable {
134 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
135 fn decode(d: &mut D) -> Self {
136 Self::from_u32(d.read_u32())
139 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
140 fn encode(&self, e: &mut E) {
141 e.emit_u32(self.private);
150 derive_paths.push(parse_quote!(Ord));
151 derive_paths.push(parse_quote!(PartialOrd));
156 impl ::std::iter::Step for #name {
158 fn steps_between(start: &Self, end: &Self) -> Option<usize> {
159 <usize as ::std::iter::Step>::steps_between(
160 &Self::index(*start),
166 fn forward_checked(start: Self, u: usize) -> Option<Self> {
167 Self::index(start).checked_add(u).map(Self::from_usize)
171 fn backward_checked(start: Self, u: usize) -> Option<Self> {
172 Self::index(start).checked_sub(u).map(Self::from_usize)
176 // Safety: The implementation of `Step` upholds all invariants.
177 unsafe impl ::std::iter::TrustedStep for #name {}
183 let debug_impl = match debug_format {
184 DebugFormat::Custom => quote! {},
185 DebugFormat::Format(format) => {
187 impl ::std::fmt::Debug for #name {
188 fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
189 write!(fmt, #format, self.as_u32())
198 #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
199 #[rustc_layout_scalar_valid_range_end(#max)]
200 #[rustc_pass_by_value]
208 /// Maximum value the index can take, as a `u32`.
209 #vis const MAX_AS_U32: u32 = #max;
211 /// Maximum value the index can take.
212 #vis const MAX: Self = Self::from_u32(#max);
214 /// Creates a new index from a given `usize`.
218 /// Will panic if `value` exceeds `MAX`.
220 #vis const fn from_usize(value: usize) -> Self {
221 assert!(value <= (#max as usize));
222 // SAFETY: We just checked that `value <= max`.
224 Self::from_u32_unchecked(value as u32)
228 /// Creates a new index from a given `u32`.
232 /// Will panic if `value` exceeds `MAX`.
234 #vis const fn from_u32(value: u32) -> Self {
235 assert!(value <= #max);
236 // SAFETY: We just checked that `value <= max`.
238 Self::from_u32_unchecked(value)
242 /// Creates a new index from a given `u32`.
246 /// The provided value must be less than or equal to the maximum value for the newtype.
247 /// Providing a value outside this range is undefined due to layout restrictions.
249 /// Prefer using `from_u32`.
251 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
252 Self { private: value }
255 /// Extracts the value of this index as a `usize`.
257 #vis const fn index(self) -> usize {
261 /// Extracts the value of this index as a `u32`.
263 #vis const fn as_u32(self) -> u32 {
267 /// Extracts the value of this index as a `usize`.
269 #vis const fn as_usize(self) -> usize {
270 self.as_u32() as usize
274 impl std::ops::Add<usize> for #name {
277 fn add(self, other: usize) -> Self {
278 Self::from_usize(self.index() + other)
282 impl rustc_index::vec::Idx for #name {
284 fn new(value: usize) -> Self {
285 Self::from_usize(value)
289 fn index(self) -> usize {
296 impl From<#name> for u32 {
298 fn from(v: #name) -> u32 {
303 impl From<#name> for usize {
305 fn from(v: #name) -> usize {
310 impl From<usize> for #name {
312 fn from(value: usize) -> Self {
313 Self::from_usize(value)
317 impl From<u32> for #name {
319 fn from(value: u32) -> Self {
320 Self::from_u32(value)
330 pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
331 let input = parse_macro_input!(input as Newtype);