]> git.lizzy.rs Git - rust.git/blob - config_proc_macro/src/item_enum.rs
Add doc comment
[rust.git] / config_proc_macro / src / item_enum.rs
1 use proc_macro2::TokenStream;
2 use quote::quote;
3
4 use crate::attrs::*;
5 use crate::utils::*;
6
7 type Variants = syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>;
8
9 /// Defines and implements `config_type` enum.
10 pub fn define_config_type_on_enum(em: &syn::ItemEnum) -> syn::Result<TokenStream> {
11     let syn::ItemEnum {
12         vis,
13         enum_token,
14         ident,
15         generics,
16         variants,
17         ..
18     } = em;
19
20     let mod_name_str = format!("__define_config_type_on_enum_{}", ident);
21     let mod_name = syn::Ident::new(&mod_name_str, ident.span());
22     let variants = fold_quote(variants.iter().map(process_variant), |meta| quote!(#meta,));
23
24     let impl_doc_hint = impl_doc_hint(&em.ident, &em.variants);
25     let impl_from_str = impl_from_str(&em.ident, &em.variants);
26     let impl_serde = impl_serde(&em.ident, &em.variants);
27     let impl_deserialize = impl_deserialize(&em.ident, &em.variants);
28
29     Ok(quote! {
30         #[allow(non_snake_case)]
31         mod #mod_name {
32             #[derive(Debug, Copy, Clone, Eq, PartialEq)]
33             pub #enum_token #ident #generics { #variants }
34             #impl_doc_hint
35             #impl_from_str
36             #impl_serde
37             #impl_deserialize
38         }
39         #vis use #mod_name::#ident;
40     })
41 }
42
43 /// Remove attributes specific to `config_proc_macro` from enum variant fields.
44 fn process_variant(variant: &syn::Variant) -> TokenStream {
45     let metas = variant
46         .attrs
47         .iter()
48         .filter(|attr| !is_doc_hint(attr) && !is_config_value(attr));
49     let attrs = fold_quote(metas, |meta| quote!(#meta));
50     let syn::Variant { ident, fields, .. } = variant;
51     quote!(#attrs #ident #fields)
52 }
53
54 fn impl_doc_hint(ident: &syn::Ident, variants: &Variants) -> TokenStream {
55     let doc_hint = variants
56         .iter()
57         .map(doc_hint_of_variant)
58         .collect::<Vec<_>>()
59         .join("|");
60     let doc_hint = format!("[{}]", doc_hint);
61     quote! {
62         use crate::config::ConfigType;
63         impl ConfigType for #ident {
64             fn doc_hint() -> String {
65                 #doc_hint.to_owned()
66             }
67         }
68     }
69 }
70
71 fn impl_from_str(ident: &syn::Ident, variants: &Variants) -> TokenStream {
72     let vs = variants
73         .iter()
74         .filter(|v| is_unit(v))
75         .map(|v| (config_value_of_variant(v), &v.ident));
76     let if_patterns = fold_quote(vs, |(s, v)| {
77         quote! {
78             if #s.eq_ignore_ascii_case(s) {
79                 return Ok(#ident::#v);
80             }
81         }
82     });
83     quote! {
84         impl ::std::str::FromStr for #ident {
85             type Err = &'static str;
86
87             fn from_str(s: &str) -> Result<Self, Self::Err> {
88                 #if_patterns
89                 return Err("Bad variant");
90             }
91         }
92     }
93 }
94
95 fn doc_hint_of_variant(variant: &syn::Variant) -> String {
96     find_doc_hint(&variant.attrs).unwrap_or(variant.ident.to_string())
97 }
98
99 fn config_value_of_variant(variant: &syn::Variant) -> String {
100     find_config_value(&variant.attrs).unwrap_or(variant.ident.to_string())
101 }
102
103 fn impl_serde(ident: &syn::Ident, variants: &Variants) -> TokenStream {
104     let arms = fold_quote(variants.iter(), |v| {
105         let v_ident = &v.ident;
106         let pattern = match v.fields {
107             syn::Fields::Named(..) => quote!(#ident::v_ident{..}),
108             syn::Fields::Unnamed(..) => quote!(#ident::#v_ident(..)),
109             syn::Fields::Unit => quote!(#ident::#v_ident),
110         };
111         let option_value = config_value_of_variant(v);
112         quote! {
113             #pattern => serializer.serialize_str(&#option_value),
114         }
115     });
116
117     quote! {
118         impl ::serde::ser::Serialize for #ident {
119             fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
120             where
121                 S: ::serde::ser::Serializer,
122             {
123                 use serde::ser::Error;
124                 match self {
125                     #arms
126                     _ => Err(S::Error::custom(format!("Cannot serialize {:?}", self))),
127                 }
128             }
129         }
130     }
131 }
132
133 // Currently only unit variants are supported.
134 fn impl_deserialize(ident: &syn::Ident, variants: &Variants) -> TokenStream {
135     let supported_vs = variants.iter().filter(|v| is_unit(v));
136     let if_patterns = fold_quote(supported_vs, |v| {
137         let config_value = config_value_of_variant(v);
138         let variant_ident = &v.ident;
139         quote! {
140             if #config_value.eq_ignore_ascii_case(s) {
141                 return Ok(#ident::#variant_ident);
142             }
143         }
144     });
145
146     let supported_vs = variants.iter().filter(|v| is_unit(v));
147     let allowed = fold_quote(supported_vs.map(config_value_of_variant), |s| quote!(#s,));
148
149     quote! {
150         impl<'de> serde::de::Deserialize<'de> for #ident {
151             fn deserialize<D>(d: D) -> Result<Self, D::Error>
152             where
153                 D: serde::Deserializer<'de>,
154             {
155                 use serde::de::{Error, Visitor};
156                 use std::marker::PhantomData;
157                 use std::fmt;
158                 struct StringOnly<T>(PhantomData<T>);
159                 impl<'de, T> Visitor<'de> for StringOnly<T>
160                 where T: serde::Deserializer<'de> {
161                     type Value = String;
162                     fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
163                         formatter.write_str("string")
164                     }
165                     fn visit_str<E>(self, value: &str) -> Result<String, E> {
166                         Ok(String::from(value))
167                     }
168                 }
169                 let s = &d.deserialize_string(StringOnly::<D>(PhantomData))?;
170
171                 #if_patterns
172
173                 static ALLOWED: &'static[&str] = &[#allowed];
174                 Err(D::Error::unknown_variant(&s, ALLOWED))
175             }
176         }
177     }
178 }