]> git.lizzy.rs Git - generate-random.git/blob - derive-macro/src/handle_enum.rs
cargo fmt
[generate-random.git] / derive-macro / src / handle_enum.rs
1 use super::generate_fields;
2 use proc_macro2::{Ident, Literal, TokenStream};
3 use quote::quote;
4 use syn::{DataEnum, Variant};
5
6 fn variant_weight(variant: &Variant) -> Literal {
7     for attr in variant.attrs.iter() {
8         if attr.path.is_ident("weight") {
9             return attr
10                 .parse_args::<Literal>()
11                 .expect("expected literal for `#[weight(...)]`");
12         }
13     }
14     Literal::u64_suffixed(1)
15 }
16
17 pub fn generate(name: &Ident, ty: DataEnum) -> TokenStream {
18     let mut variant_weights = ty
19         .variants
20         .into_iter()
21         .map(|variant| (variant_weight(&variant), variant));
22
23     let mut arms = TokenStream::new();
24     let mut total_weight = quote! { 0 };
25     if let Some((weight, variant)) = variant_weights.next() {
26         let variant_name = variant.ident;
27         let fields = generate_fields(variant.fields);
28         arms.extend(quote! {
29             let end = #weight;
30             if 0 <= value && value < end {
31                 return Self::#variant_name #fields
32             }
33         });
34         total_weight = quote! { #weight };
35         for (weight, variant) in variant_weights {
36             let variant_name = variant.ident;
37             let fields = generate_fields(variant.fields);
38             arms.extend(quote! {
39                 let start = end;
40                 let end = start + #weight;
41                 if start <= value && value < end {
42                     return Self::#variant_name #fields
43                 }
44             });
45             total_weight = quote! { #total_weight + #weight };
46         }
47     }
48
49     quote! {
50         impl generate_random::GenerateRandom for #name {
51             fn generate_random<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
52                 let total_weight = #total_weight;
53                 let value = rng.gen_range(0..total_weight);
54                 #arms
55                 unreachable!()
56             }
57         }
58     }
59 }