1 use proc_macro2::TokenStream;
7 type Variants = syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>;
9 /// Defines and implements `config_type` enum.
10 pub fn define_config_type_on_enum(em: &syn::ItemEnum) -> syn::Result<TokenStream> {
20 let mod_name_str = format!("__define_config_type_on_enum_{}", ident);
21 let mod_name = syn::Ident::new(&mod_name_str, ident.span());
22 let variants = fold_quote(variants.iter().map(process_variant), |meta| quote!(#meta,));
24 let impl_doc_hint = impl_doc_hint(&em.ident, &em.variants);
25 let impl_from_str = impl_from_str(&em.ident, &em.variants);
26 let impl_serde = impl_serde(&em.ident, &em.variants);
27 let impl_deserialize = impl_deserialize(&em.ident, &em.variants);
30 #[allow(non_snake_case)]
32 #[derive(Debug, Copy, Clone, Eq, PartialEq)]
33 pub #enum_token #ident #generics { #variants }
39 #vis use #mod_name::#ident;
43 /// Remove attributes specific to `config_proc_macro` from enum variant fields.
44 fn process_variant(variant: &syn::Variant) -> TokenStream {
48 .filter(|attr| !is_doc_hint(attr) && !is_config_value(attr));
49 let attrs = fold_quote(metas, |meta| quote!(#meta));
50 let syn::Variant { ident, fields, .. } = variant;
51 quote!(#attrs #ident #fields)
54 fn impl_doc_hint(ident: &syn::Ident, variants: &Variants) -> TokenStream {
55 let doc_hint = variants
57 .map(doc_hint_of_variant)
60 let doc_hint = format!("[{}]", doc_hint);
62 use crate::config::ConfigType;
63 impl ConfigType for #ident {
64 fn doc_hint() -> String {
71 fn impl_from_str(ident: &syn::Ident, variants: &Variants) -> TokenStream {
74 .filter(|v| is_unit(v))
75 .map(|v| (config_value_of_variant(v), &v.ident));
76 let if_patterns = fold_quote(vs, |(s, v)| {
78 if #s.eq_ignore_ascii_case(s) {
79 return Ok(#ident::#v);
84 impl ::std::str::FromStr for #ident {
85 type Err = &'static str;
87 fn from_str(s: &str) -> Result<Self, Self::Err> {
89 return Err("Bad variant");
95 fn doc_hint_of_variant(variant: &syn::Variant) -> String {
96 find_doc_hint(&variant.attrs).unwrap_or(variant.ident.to_string())
99 fn config_value_of_variant(variant: &syn::Variant) -> String {
100 find_config_value(&variant.attrs).unwrap_or(variant.ident.to_string())
103 fn impl_serde(ident: &syn::Ident, variants: &Variants) -> TokenStream {
104 let arms = fold_quote(variants.iter(), |v| {
105 let v_ident = &v.ident;
106 let pattern = match v.fields {
107 syn::Fields::Named(..) => quote!(#ident::v_ident{..}),
108 syn::Fields::Unnamed(..) => quote!(#ident::#v_ident(..)),
109 syn::Fields::Unit => quote!(#ident::#v_ident),
111 let option_value = config_value_of_variant(v);
113 #pattern => serializer.serialize_str(&#option_value),
118 impl ::serde::ser::Serialize for #ident {
119 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
121 S: ::serde::ser::Serializer,
123 use serde::ser::Error;
126 _ => Err(S::Error::custom(format!("Cannot serialize {:?}", self))),
133 // Currently only unit variants are supported.
134 fn impl_deserialize(ident: &syn::Ident, variants: &Variants) -> TokenStream {
135 let supported_vs = variants.iter().filter(|v| is_unit(v));
136 let if_patterns = fold_quote(supported_vs, |v| {
137 let config_value = config_value_of_variant(v);
138 let variant_ident = &v.ident;
140 if #config_value.eq_ignore_ascii_case(s) {
141 return Ok(#ident::#variant_ident);
146 let supported_vs = variants.iter().filter(|v| is_unit(v));
147 let allowed = fold_quote(supported_vs.map(config_value_of_variant), |s| quote!(#s,));
150 impl<'de> serde::de::Deserialize<'de> for #ident {
151 fn deserialize<D>(d: D) -> Result<Self, D::Error>
153 D: serde::Deserializer<'de>,
155 use serde::de::{Error, Visitor};
156 use std::marker::PhantomData;
158 struct StringOnly<T>(PhantomData<T>);
159 impl<'de, T> Visitor<'de> for StringOnly<T>
160 where T: serde::Deserializer<'de> {
162 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
163 formatter.write_str("string")
165 fn visit_str<E>(self, value: &str) -> Result<String, E> {
166 Ok(String::from(value))
169 let s = &d.deserialize_string(StringOnly::<D>(PhantomData))?;
173 static ALLOWED: &'static[&str] = &[#allowed];
174 Err(D::Error::unknown_variant(&s, ALLOWED))