]> git.lizzy.rs Git - generate-random.git/blob - derive-macro/src/handle_enum.rs
b2b14c0a5f0083a84af30748e8423c1bc3277073
[generate-random.git] / derive-macro / src / handle_enum.rs
1 use proc_macro2::{Ident, Literal, TokenStream};
2 use quote::quote;
3 use syn::{DataEnum, Variant};
4 use super::generate_fields;
5
6 fn variant_weight(variant: &Variant) -> Literal {
7     for attr in variant.attrs.iter() {
8         if attr.path.is_ident("weight") {
9             return attr.parse_args::<Literal>().expect("expected literal for `#[weight(...)]`")
10         }
11     }
12     Literal::u64_suffixed(1)
13 }
14
15 pub fn generate(name: &Ident, ty: DataEnum) -> TokenStream {
16     let mut variant_weights = ty.variants.into_iter()
17         .map(|variant| (variant_weight(&variant), variant));
18
19     let mut arms = TokenStream::new();
20     let mut total_weight = quote! { 0 };
21     if let Some((weight, variant)) = variant_weights.next() {
22         let variant_name = variant.ident;
23         let fields = generate_fields(variant.fields);
24         arms.extend(quote! {
25             let end = #weight;
26             if 0 <= value && value < end {
27                 return Self::#variant_name #fields
28             }
29         });
30         total_weight = quote! { #weight };
31         for (weight, variant) in variant_weights {
32             let variant_name = variant.ident;
33             let fields = generate_fields(variant.fields);
34             arms.extend(quote! {
35                 let start = end;
36                 let end = start + #weight;
37                 if start <= value && value < end {
38                     return Self::#variant_name #fields
39                 }
40             });
41             total_weight = quote! { #total_weight + #weight };
42         }
43     }
44
45     quote! {
46         impl generate_random::GenerateRandom for #name {
47             fn generate_random<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
48                 let total_weight = #total_weight;
49                 let value = rng.gen_range(0..total_weight);
50                 #arms
51                 unreachable!()
52             }
53         }
54     }
55 }